@@ -28,15 +28,21 @@ use crate::binder::bind_mutation::bind::MutationStrategy;
28
28
use crate :: binder:: bind_mutation:: mutation_expression:: MutationExpression ;
29
29
use crate :: binder:: util:: TableIdentifier ;
30
30
use crate :: binder:: Binder ;
31
- use crate :: optimizer:: ir:: SExpr ;
31
+ use crate :: optimizer:: ir:: Matcher ;
32
32
use crate :: plans:: AggregateFunction ;
33
33
use crate :: plans:: BoundColumnRef ;
34
+ use crate :: plans:: EvalScalar ;
34
35
use crate :: plans:: Plan ;
36
+ use crate :: plans:: RelOp ;
35
37
use crate :: plans:: RelOperator ;
36
38
use crate :: plans:: ScalarItem ;
37
39
use crate :: plans:: VisitorMut ;
38
40
use crate :: BindContext ;
41
+ use crate :: ColumnBinding ;
42
+ use crate :: ColumnBindingBuilder ;
43
+ use crate :: IndexType ;
39
44
use crate :: ScalarExpr ;
45
+ use crate :: Visibility ;
40
46
41
47
impl Binder {
42
48
#[ async_backtrace:: framed]
@@ -118,17 +124,19 @@ impl Binder {
118
124
let Plan :: DataMutation { box s_expr, .. } = & plan else {
119
125
return Ok ( plan) ;
120
126
} ;
121
- let RelOperator :: Mutation ( mutation) = & * s_expr. plan else {
127
+ let RelOperator :: Mutation ( mutation) = s_expr. plan ( ) else {
122
128
return Ok ( plan) ;
123
129
} ;
124
- let filter_expr = & s_expr. children [ 0 ] ;
125
- let RelOperator :: Filter ( _) = & * filter_expr. plan else {
126
- return Ok ( plan) ;
130
+ let matcher = Matcher :: MatchOp {
131
+ op_type : RelOp :: Filter ,
132
+ children : vec ! [ Matcher :: MatchOp {
133
+ op_type: RelOp :: Join ,
134
+ children: vec![ Matcher :: Leaf , Matcher :: Leaf ] ,
135
+ } ] ,
127
136
} ;
128
- let input = & filter_expr. children [ 0 ] ;
129
- let RelOperator :: Join ( _) = & * input. plan else {
137
+ if !matcher. matches ( s_expr. unary_child ( ) ) {
130
138
return Ok ( plan) ;
131
- } ;
139
+ }
132
140
133
141
let mut mutation = mutation. clone ( ) ;
134
142
@@ -176,15 +184,20 @@ impl Binder {
176
184
. flat_map ( |expr| expr. used_columns ( ) . into_iter ( ) )
177
185
} )
178
186
} )
179
- . chain ( mutation. required_columns . iter ( ) . copied ( ) )
180
187
. collect :: < HashSet < _ > > ( ) ;
181
188
182
189
let used_columns = used_columns
183
190
. difference ( & fields_bindings. iter ( ) . map ( |column| column. index ) . collect ( ) )
184
191
. copied ( )
185
192
. collect :: < HashSet < _ > > ( ) ;
186
193
187
- let aggr_columns = used_columns
194
+ struct AnyColumn {
195
+ old : IndexType ,
196
+ new : IndexType ,
197
+ cast : Option < ScalarExpr > ,
198
+ }
199
+
200
+ let mut any_columns = used_columns
188
201
. iter ( )
189
202
. copied ( )
190
203
. filter_map ( |i| {
@@ -201,7 +214,7 @@ impl Binder {
201
214
202
215
let display_name = format ! ( "any({})" , binding. index) ;
203
216
let old = binding. index ;
204
- let mut aggr_func = ScalarExpr :: AggregateFunction ( AggregateFunction {
217
+ let mut any_func : ScalarExpr = AggregateFunction {
205
218
span : None ,
206
219
func_name : "any" . to_string ( ) ,
207
220
distinct : false ,
@@ -210,14 +223,15 @@ impl Binder {
210
223
span: None ,
211
224
column: binding. clone( ) ,
212
225
} ) ] ,
213
- return_type : binding. data_type . clone ( ) ,
226
+ return_type : Box :: new ( binding. data_type . wrap_nullable ( ) ) ,
214
227
sort_descs : vec ! [ ] ,
215
228
display_name : display_name. clone ( ) ,
216
- } ) ;
229
+ }
230
+ . into ( ) ;
217
231
218
232
let mut rewriter =
219
233
AggregateRewriter :: new ( & mut mutation. bind_context , self . metadata . clone ( ) ) ;
220
- rewriter. visit ( & mut aggr_func ) . unwrap ( ) ;
234
+ rewriter. visit ( & mut any_func ) . unwrap ( ) ;
221
235
222
236
let new = mutation
223
237
. bind_context
@@ -226,10 +240,48 @@ impl Binder {
226
240
. unwrap ( )
227
241
. index ;
228
242
229
- Some ( ( aggr_func, old, new) )
243
+ let ( cast, new) = if !binding. data_type . is_nullable ( ) {
244
+ let ColumnBinding {
245
+ column_name,
246
+ data_type,
247
+ ..
248
+ } = binding;
249
+
250
+ let column = ColumnBindingBuilder :: new (
251
+ column_name. clone ( ) ,
252
+ new,
253
+ data_type. clone ( ) ,
254
+ Visibility :: Visible ,
255
+ )
256
+ . build ( ) ;
257
+ let column = ScalarExpr :: BoundColumnRef ( BoundColumnRef { span : None , column } ) ;
258
+ let cast = column. unify_to_data_type ( & data_type) ;
259
+
260
+ let index = self . metadata . write ( ) . add_derived_column (
261
+ column_name,
262
+ * data_type,
263
+ Some ( cast. clone ( ) ) ,
264
+ ) ;
265
+ ( Some ( cast) , index)
266
+ } else {
267
+ ( None , new)
268
+ } ;
269
+
270
+ Some ( AnyColumn { old, new, cast } )
230
271
} )
231
272
. collect :: < Vec < _ > > ( ) ;
232
273
274
+ let items = any_columns
275
+ . iter_mut ( )
276
+ . filter_map ( |col| {
277
+ col. cast . take ( ) . map ( |scalar| ScalarItem {
278
+ scalar,
279
+ index : col. new ,
280
+ } )
281
+ } )
282
+ . collect ( ) ;
283
+ let eval_scalar = EvalScalar { items } ;
284
+
233
285
mutation. bind_context . aggregate_info . group_items = fields_bindings
234
286
. into_iter ( )
235
287
. chain ( std:: iter:: once ( row_id) )
@@ -241,55 +293,52 @@ impl Binder {
241
293
242
294
for eval in & mut mutation. matched_evaluators {
243
295
if let Some ( expr) = & mut eval. condition {
244
- for ( _ , old , new ) in & aggr_columns {
245
- expr. replace_column ( * old, * new) ?
296
+ for col in & any_columns {
297
+ expr. replace_column ( col . old , col . new ) ?
246
298
}
247
299
}
248
300
249
301
if let Some ( update) = & mut eval. update {
250
302
for ( _, expr) in update. iter_mut ( ) {
251
- for ( _ , old , new ) in & aggr_columns {
252
- expr. replace_column ( * old, * new) ?
303
+ for col in & any_columns {
304
+ expr. replace_column ( col . old , col . new ) ?
253
305
}
254
306
}
255
307
}
256
308
}
257
309
258
310
for ( _, column) in mutation. field_index_map . iter_mut ( ) {
259
- if let Some ( ( _, _, index) ) = aggr_columns
260
- . iter ( )
261
- . find ( |( _, i, _) | i. to_string ( ) == * column)
262
- {
263
- * column = index. to_string ( )
311
+ if let Some ( col) = any_columns. iter ( ) . find ( |c| c. old . to_string ( ) == * column) {
312
+ * column = col. new . to_string ( )
264
313
} ;
265
314
}
266
315
267
316
mutation. required_columns = Box :: new (
268
317
std:: iter:: once ( mutation. row_id_index )
269
- . chain ( aggr_columns . into_iter ( ) . map ( |( _ , _ , i ) | i ) )
318
+ . chain ( any_columns . into_iter ( ) . map ( |c| c . new ) )
270
319
. collect ( ) ,
271
320
) ;
272
321
273
- let aggr_expr = self . bind_aggregate ( & mut mutation. bind_context , ( * * filter_expr) . clone ( ) ) ?;
322
+ let aggr_expr =
323
+ self . bind_aggregate ( & mut mutation. bind_context , s_expr. unary_child ( ) . clone ( ) ) ?;
274
324
275
- let s_expr = SExpr :: create_unary (
276
- Arc :: new ( RelOperator :: Mutation ( mutation) ) ,
277
- Arc :: new ( aggr_expr) ,
278
- ) ;
325
+ let input = if eval_scalar. items . is_empty ( ) {
326
+ aggr_expr
327
+ } else {
328
+ aggr_expr. build_unary ( Arc :: new ( eval_scalar. into ( ) ) )
329
+ } ;
279
330
331
+ let s_expr = Box :: new ( input. build_unary ( Arc :: new ( mutation. into ( ) ) ) ) ;
280
332
let Plan :: DataMutation {
281
333
schema, metadata, ..
282
334
} = plan
283
335
else {
284
336
unreachable ! ( )
285
337
} ;
286
-
287
- let plan = Plan :: DataMutation {
288
- s_expr : Box :: new ( s_expr) ,
338
+ Ok ( Plan :: DataMutation {
339
+ s_expr,
289
340
schema,
290
341
metadata,
291
- } ;
292
-
293
- Ok ( plan)
342
+ } )
294
343
}
295
344
}
0 commit comments