Skip to content
This repository was archived by the owner on Feb 27, 2025. It is now read-only.

Commit 07fee04

Browse files
committed
Add support for graph tables / temporal tables
1 parent 1d57c39 commit 07fee04

File tree

2 files changed

+35
-52
lines changed

2 files changed

+35
-52
lines changed

src/main/scala/com/microsoft/sqlserver/jdbc/spark/SQLServerBulkJdbcOptions.scala

+1-3
Original file line numberDiff line numberDiff line change
@@ -68,12 +68,10 @@ class SQLServerBulkJdbcOptions(val params: CaseInsensitiveMap[String])
6868
val allowEncryptedValueModifications =
6969
params.getOrElse("allowEncryptedValueModifications", "false").toBoolean
7070

71+
7172
val schemaCheckEnabled =
7273
params.getOrElse("schemaCheckEnabled", "true").toBoolean
7374

74-
val hideGraphColumns =
75-
params.getOrElse("hideGraphColumns", "true").toBoolean
76-
7775
// Not a feature
7876
// Only used for internally testing data idempotency
7977
val testDataIdempotency =

src/main/scala/com/microsoft/sqlserver/jdbc/spark/utils/BulkCopyUtils.scala

+34-49
Original file line numberDiff line numberDiff line change
@@ -180,60 +180,47 @@ object BulkCopyUtils extends Logging {
180180
}
181181

182182
/**
183-
* getComputedCols
184-
* utility function to get computed columns.
185-
* Use computed column names to exclude computed column when matching schema.
183+
* getAutoCols
184+
* utility function to get auto generated columns.
185+
* Use auto generated column names to exclude them when matching schema.
186186
*/
187-
private[spark] def getComputedCols(
187+
private[spark] def getAutoCols(
188188
conn: Connection,
189-
table: String,
190-
hideGraphColumns: Boolean): List[String] = {
191-
// TODO can optimize this, also evaluate SQLi issues
192-
val queryStr = if (hideGraphColumns) s"""IF (SERVERPROPERTY('EngineEdition') = 5 OR SERVERPROPERTY('ProductMajorVersion') >= 14)
193-
exec sp_executesql N'SELECT name
194-
FROM sys.computed_columns
195-
WHERE object_id = OBJECT_ID(''${table}'')
196-
UNION ALL
197-
SELECT C.name
198-
FROM sys.tables AS T
199-
JOIN sys.columns AS C
200-
ON T.object_id = C.object_id
201-
WHERE T.object_id = OBJECT_ID(''${table}'')
202-
AND (T.is_edge = 1 OR T.is_node = 1)
203-
AND C.is_hidden = 0
204-
AND C.graph_type = 2'
205-
ELSE
206-
SELECT name
207-
FROM sys.computed_columns
189+
table: String): List[String] = {
190+
// auto cols union computed cols, generated always cols, and node / edge table auto cols
191+
val queryStr = s"""SELECT name
192+
FROM sys.columns
208193
WHERE object_id = OBJECT_ID('${table}')
194+
AND (is_computed = 1 -- computed column
195+
OR generated_always_type > 0 -- generated always / temporal table
196+
OR (is_hidden = 0 AND graph_type = 2)) -- graph table
209197
"""
210-
else s"SELECT name FROM sys.computed_columns WHERE object_id = OBJECT_ID('${table}');"
211198

212-
val computedColRs = conn.createStatement.executeQuery(queryStr)
213-
val computedCols = ListBuffer[String]()
214-
while (computedColRs.next()) {
215-
val colName = computedColRs.getString("name")
216-
computedCols.append(colName)
199+
val autoColRs = conn.createStatement.executeQuery(queryStr)
200+
val autoCols = ListBuffer[String]()
201+
while (autoColRs.next()) {
202+
val colName = autoColRs.getString("name")
203+
autoCols.append(colName)
217204
}
218-
computedCols.toList
205+
autoCols.toList
219206
}
220207

221208
/**
222-
* dfComputedColCount
209+
* dfAutoColCount
223210
* utility function to get number of computed columns in dataframe.
224211
* Use number of computed columns in dataframe to get number of non computed column in df,
225212
* and compare with the number of non computed column in sql table
226213
*/
227-
private[spark] def dfComputedColCount(
214+
private[spark] def dfAutoColCount(
228215
dfColNames: List[String],
229-
computedCols: List[String],
216+
autoCols: List[String],
230217
dfColCaseMap: Map[String, String],
231218
isCaseSensitive: Boolean): Int ={
232219
var dfComputedColCt = 0
233-
for (j <- 0 to computedCols.length-1){
234-
if (isCaseSensitive && dfColNames.contains(computedCols(j)) ||
235-
!isCaseSensitive && dfColCaseMap.contains(computedCols(j).toLowerCase())
236-
&& dfColCaseMap(computedCols(j).toLowerCase()) == computedCols(j)) {
220+
for (j <- 0 to autoCols.length-1){
221+
if (isCaseSensitive && dfColNames.contains(autoCols(j)) ||
222+
!isCaseSensitive && dfColCaseMap.contains(autoCols(j).toLowerCase())
223+
&& dfColCaseMap(autoCols(j).toLowerCase()) == autoCols(j)) {
237224
dfComputedColCt += 1
238225
}
239226
}
@@ -284,7 +271,7 @@ SELECT name
284271
val colMetaData = {
285272
if(checkSchema) {
286273
checkExTableType(conn, options)
287-
matchSchemas(conn, options.dbtable, df, rs, options.url, isCaseSensitive, options.schemaCheckEnabled, options.hideGraphColumns)
274+
matchSchemas(conn, options.dbtable, df, rs, options.url, isCaseSensitive, options.schemaCheckEnabled)
288275
} else {
289276
defaultColMetadataMap(rs.getMetaData())
290277
}
@@ -310,7 +297,6 @@ SELECT name
310297
* @param url: String,
311298
* @param isCaseSensitive: Boolean
312299
* @param strictSchemaCheck: Boolean
313-
* @param hideGraphColumns - Whether to hide the $node_id, $from_id, $to_id, $edge_id columns in SQL graph tables
314300
*/
315301
private[spark] def matchSchemas(
316302
conn: Connection,
@@ -319,40 +305,39 @@ SELECT name
319305
rs: ResultSet,
320306
url: String,
321307
isCaseSensitive: Boolean,
322-
strictSchemaCheck: Boolean,
323-
hideGraphColumns: Boolean): Array[ColumnMetadata]= {
308+
strictSchemaCheck: Boolean): Array[ColumnMetadata]= {
324309
val dfColCaseMap = (df.schema.fieldNames.map(item => item.toLowerCase)
325310
zip df.schema.fieldNames.toList).toMap
326311
val dfCols = df.schema
327312

328313
val tableCols = getSchema(rs, JdbcDialects.get(url))
329-
val computedCols = getComputedCols(conn, dbtable, hideGraphColumns)
314+
val autoCols = getAutoCols(conn, dbtable)
330315

331316
val prefix = "Spark Dataframe and SQL Server table have differing"
332317

333-
if (computedCols.length == 0) {
318+
if (autoCols.length == 0) {
334319
assertIfCheckEnabled(dfCols.length == tableCols.length, strictSchemaCheck,
335320
s"${prefix} numbers of columns")
336321
} else if (strictSchemaCheck) {
337322
val dfColNames = df.schema.fieldNames.toList
338-
val dfComputedColCt = dfComputedColCount(dfColNames, computedCols, dfColCaseMap, isCaseSensitive)
323+
val dfComputedColCt = dfAutoColCount(dfColNames, autoCols, dfColCaseMap, isCaseSensitive)
339324
// if df has computed column(s), check column length using non computed column in df and table.
340325
// non computed column number in df: dfCols.length - dfComputedColCt
341-
// non computed column number in table: tableCols.length - computedCols.length
342-
assertIfCheckEnabled(dfCols.length-dfComputedColCt == tableCols.length-computedCols.length, strictSchemaCheck,
326+
// non computed column number in table: tableCols.length - autoCols.length
327+
assertIfCheckEnabled(dfCols.length-dfComputedColCt == tableCols.length-autoCols.length, strictSchemaCheck,
343328
s"${prefix} numbers of columns")
344329
}
345330

346331

347-
val result = new Array[ColumnMetadata](tableCols.length - computedCols.length)
332+
val result = new Array[ColumnMetadata](tableCols.length - autoCols.length)
348333
var nonAutoColIndex = 0
349334

350335
for (i <- 0 to tableCols.length-1) {
351336
val tableColName = tableCols(i).name
352337
var dfFieldIndex = -1
353338
// set dfFieldIndex = -1 for all computed columns to skip ColumnMetadata
354-
if (computedCols.contains(tableColName)) {
355-
logDebug(s"skipping computed col index $i col name $tableColName dfFieldIndex $dfFieldIndex")
339+
if (autoCols.contains(tableColName)) {
340+
logDebug(s"skipping auto generated col index $i col name $tableColName dfFieldIndex $dfFieldIndex")
356341
}else{
357342
var dfColName:String = ""
358343
if (isCaseSensitive) {

0 commit comments

Comments (0)