@@ -180,60 +180,47 @@ object BulkCopyUtils extends Logging {
     }

     /**
-     * getComputedCols
-     * utility function to get computed columns.
-     * Use computed column names to exclude computed column when matching schema.
+     * getAutoCols
+     * utility function to get auto generated columns.
+     * Use auto generated column names to exclude them when matching schema.
      */
-    private[spark] def getComputedCols(
+    private[spark] def getAutoCols(
             conn: Connection,
-            table: String,
-            hideGraphColumns: Boolean): List[String] = {
-        // TODO can optimize this, also evaluate SQLi issues
-        val queryStr = if (hideGraphColumns) s"""IF (SERVERPROPERTY('EngineEdition') = 5 OR SERVERPROPERTY('ProductMajorVersion') >= 14)
-exec sp_executesql N'SELECT name
-FROM sys.computed_columns
-WHERE object_id = OBJECT_ID(''${table}'')
-UNION ALL
-SELECT C.name
-FROM sys.tables AS T
-JOIN sys.columns AS C
-ON T.object_id = C.object_id
-WHERE T.object_id = OBJECT_ID(''${table}'')
-AND (T.is_edge = 1 OR T.is_node = 1)
-AND C.is_hidden = 0
-AND C.graph_type = 2'
-ELSE
-SELECT name
-FROM sys.computed_columns
+            table: String): List[String] = {
+        // auto cols union computed cols, generated always cols, and node / edge table auto cols
+        val queryStr = s"""SELECT name
+FROM sys.columns
 WHERE object_id = OBJECT_ID('${table}')
+AND (is_computed = 1 -- computed column
+OR generated_always_type > 0 -- generated always / temporal table
+OR (is_hidden = 0 AND graph_type = 2)) -- graph table
 """
-        else s"SELECT name FROM sys.computed_columns WHERE object_id = OBJECT_ID('${table}');"

-        val computedColRs = conn.createStatement.executeQuery(queryStr)
-        val computedCols = ListBuffer[String]()
-        while (computedColRs.next()) {
-            val colName = computedColRs.getString("name")
-            computedCols.append(colName)
+        val autoColRs = conn.createStatement.executeQuery(queryStr)
+        val autoCols = ListBuffer[String]()
+        while (autoColRs.next()) {
+            val colName = autoColRs.getString("name")
+            autoCols.append(colName)
         }
-        computedCols.toList
+        autoCols.toList
     }

     /**
-     * dfComputedColCount
+     * dfAutoColCount
      * utility function to get number of computed columns in dataframe.
      * Use number of computed columns in dataframe to get number of non computed column in df,
      * and compare with the number of non computed column in sql table
      */
-    private[spark] def dfComputedColCount(
+    private[spark] def dfAutoColCount(
             dfColNames: List[String],
-            computedCols: List[String],
+            autoCols: List[String],
             dfColCaseMap: Map[String, String],
             isCaseSensitive: Boolean): Int = {
         var dfComputedColCt = 0
-        for (j <- 0 to computedCols.length-1){
-            if (isCaseSensitive && dfColNames.contains(computedCols(j)) ||
-              !isCaseSensitive && dfColCaseMap.contains(computedCols(j).toLowerCase())
-                && dfColCaseMap(computedCols(j).toLowerCase()) == computedCols(j)) {
+        for (j <- 0 to autoCols.length-1){
+            if (isCaseSensitive && dfColNames.contains(autoCols(j)) ||
+              !isCaseSensitive && dfColCaseMap.contains(autoCols(j).toLowerCase())
+                && dfColCaseMap(autoCols(j).toLowerCase()) == autoCols(j)) {
                 dfComputedColCt += 1
             }
         }
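
The new query above replaces the old two-branch probe (computed columns, plus a version-gated lookup of hidden graph columns) with a single scan of `sys.columns` that also catches generated-always temporal columns. For reference, a minimal standalone sketch of the same probe against a hypothetical table `dbo.Events` (the table name, JDBC URL, and credentials are illustrative placeholders, not part of this change):

```scala
// Sketch: probe a hypothetical table for auto generated columns the same
// way the new getAutoCols query does. Connection details are placeholders.
import java.sql.DriverManager
import scala.collection.mutable.ListBuffer

object AutoColsProbe extends App {
  val table = "dbo.Events" // hypothetical table
  val queryStr = s"""SELECT name
FROM sys.columns
WHERE object_id = OBJECT_ID('${table}')
AND (is_computed = 1                     -- computed column
  OR generated_always_type > 0           -- generated always / temporal table
  OR (is_hidden = 0 AND graph_type = 2)) -- graph (node/edge) pseudo-column
"""
  val conn = DriverManager.getConnection(
    "jdbc:sqlserver://localhost;databaseName=TestDb;user=sa;password=<secret>")
  val rs = conn.createStatement.executeQuery(queryStr)
  val autoCols = ListBuffer[String]()
  while (rs.next()) autoCols.append(rs.getString("name"))
  println(autoCols.toList) // e.g. List(Total, ValidFrom, ValidTo)
  conn.close()
}
```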
@@ -284,7 +271,7 @@ SELECT name
         val colMetaData = {
             if (checkSchema) {
                 checkExTableType(conn, options)
-                matchSchemas(conn, options.dbtable, df, rs, options.url, isCaseSensitive, options.schemaCheckEnabled, options.hideGraphColumns)
+                matchSchemas(conn, options.dbtable, df, rs, options.url, isCaseSensitive, options.schemaCheckEnabled)
             } else {
                 defaultColMetadataMap(rs.getMetaData())
             }
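
With `hideGraphColumns` removed from `matchSchemas`, the call site above no longer threads the flag through, and the corresponding user-facing option goes away as well. A hypothetical caller-side write after this change (URL and table name are placeholders):

```scala
// Sketch: appending a DataFrame after this change. Graph pseudo-columns
// and other auto generated columns are now detected automatically, so no
// hideGraphColumns option is passed.
df.write
  .format("com.microsoft.sqlserver.jdbc.spark")
  .mode("append")
  .option("url", "jdbc:sqlserver://localhost;databaseName=TestDb")
  .option("dbtable", "dbo.Events") // hypothetical table
  .save()
```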
@@ -310,7 +297,6 @@ SELECT name
      * @param url: String,
      * @param isCaseSensitive: Boolean
      * @param strictSchemaCheck: Boolean
-     * @param hideGraphColumns - Whether to hide the $node_id, $from_id, $to_id, $edge_id columns in SQL graph tables
      */
     private[spark] def matchSchemas(
             conn: Connection,
@@ -319,40 +305,39 @@ SELECT name
             rs: ResultSet,
             url: String,
             isCaseSensitive: Boolean,
-            strictSchemaCheck: Boolean,
-            hideGraphColumns: Boolean): Array[ColumnMetadata] = {
+            strictSchemaCheck: Boolean): Array[ColumnMetadata] = {
         val dfColCaseMap = (df.schema.fieldNames.map(item => item.toLowerCase)
             zip df.schema.fieldNames.toList).toMap
         val dfCols = df.schema

         val tableCols = getSchema(rs, JdbcDialects.get(url))
-        val computedCols = getComputedCols(conn, dbtable, hideGraphColumns)
+        val autoCols = getAutoCols(conn, dbtable)

         val prefix = "Spark Dataframe and SQL Server table have differing"

-        if (computedCols.length == 0) {
+        if (autoCols.length == 0) {
             assertIfCheckEnabled(dfCols.length == tableCols.length, strictSchemaCheck,
                 s"${prefix} numbers of columns")
         } else if (strictSchemaCheck) {
             val dfColNames = df.schema.fieldNames.toList
-            val dfComputedColCt = dfComputedColCount(dfColNames, computedCols, dfColCaseMap, isCaseSensitive)
+            val dfComputedColCt = dfAutoColCount(dfColNames, autoCols, dfColCaseMap, isCaseSensitive)
             // if df has computed column(s), check column length using non computed column in df and table.
             // non computed column number in df: dfCols.length - dfComputedColCt
-            // non computed column number in table: tableCols.length - computedCols.length
-            assertIfCheckEnabled(dfCols.length-dfComputedColCt == tableCols.length-computedCols.length, strictSchemaCheck,
+            // non computed column number in table: tableCols.length - autoCols.length
+            assertIfCheckEnabled(dfCols.length-dfComputedColCt == tableCols.length-autoCols.length, strictSchemaCheck,
                 s"${prefix} numbers of columns")
         }


-        val result = new Array[ColumnMetadata](tableCols.length - computedCols.length)
+        val result = new Array[ColumnMetadata](tableCols.length - autoCols.length)
         var nonAutoColIndex = 0

         for (i <- 0 to tableCols.length-1) {
             val tableColName = tableCols(i).name
             var dfFieldIndex = -1
             // set dfFieldIndex = -1 for all computed columns to skip ColumnMetadata
-            if (computedCols.contains(tableColName)) {
-                logDebug(s"skipping computed col index $i col name $tableColName dfFieldIndex $dfFieldIndex")
+            if (autoCols.contains(tableColName)) {
+                logDebug(s"skipping auto generated col index $i col name $tableColName dfFieldIndex $dfFieldIndex")
             } else {
                 var dfColName: String = ""
                 if (isCaseSensitive) {