Skip to content

Commit 00f03ae

Browse files
authored
fix(binder): nondeterministic_update should accept non-nullable inputs (#17874)
* fix
* test

Signed-off-by: coldWater <forsaken628@gmail.com>

* refine

Signed-off-by: coldWater <forsaken628@gmail.com>

---------

Signed-off-by: coldWater <forsaken628@gmail.com>
1 parent bf78da4 commit 00f03ae

File tree

8 files changed

+134
-50
lines changed

8 files changed

+134
-50
lines changed

src/query/expression/src/evaluator.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -125,7 +125,7 @@ impl<'a> Evaluator<'a> {
125125
&column.data_type,
126126
data_type,
127127
"column data type mismatch at index: {index}, expr: {}",
128-
expr.sql_display(),
128+
expr.fmt_with_options(true)
129129
);
130130
}
131131
}

src/query/sql/src/planner/binder/bind_mutation/update.rs

+85-36
Original file line numberDiff line numberDiff line change
@@ -28,15 +28,21 @@ use crate::binder::bind_mutation::bind::MutationStrategy;
2828
use crate::binder::bind_mutation::mutation_expression::MutationExpression;
2929
use crate::binder::util::TableIdentifier;
3030
use crate::binder::Binder;
31-
use crate::optimizer::ir::SExpr;
31+
use crate::optimizer::ir::Matcher;
3232
use crate::plans::AggregateFunction;
3333
use crate::plans::BoundColumnRef;
34+
use crate::plans::EvalScalar;
3435
use crate::plans::Plan;
36+
use crate::plans::RelOp;
3537
use crate::plans::RelOperator;
3638
use crate::plans::ScalarItem;
3739
use crate::plans::VisitorMut;
3840
use crate::BindContext;
41+
use crate::ColumnBinding;
42+
use crate::ColumnBindingBuilder;
43+
use crate::IndexType;
3944
use crate::ScalarExpr;
45+
use crate::Visibility;
4046

4147
impl Binder {
4248
#[async_backtrace::framed]
@@ -118,17 +124,19 @@ impl Binder {
118124
let Plan::DataMutation { box s_expr, .. } = &plan else {
119125
return Ok(plan);
120126
};
121-
let RelOperator::Mutation(mutation) = &*s_expr.plan else {
127+
let RelOperator::Mutation(mutation) = s_expr.plan() else {
122128
return Ok(plan);
123129
};
124-
let filter_expr = &s_expr.children[0];
125-
let RelOperator::Filter(_) = &*filter_expr.plan else {
126-
return Ok(plan);
130+
let matcher = Matcher::MatchOp {
131+
op_type: RelOp::Filter,
132+
children: vec![Matcher::MatchOp {
133+
op_type: RelOp::Join,
134+
children: vec![Matcher::Leaf, Matcher::Leaf],
135+
}],
127136
};
128-
let input = &filter_expr.children[0];
129-
let RelOperator::Join(_) = &*input.plan else {
137+
if !matcher.matches(s_expr.unary_child()) {
130138
return Ok(plan);
131-
};
139+
}
132140

133141
let mut mutation = mutation.clone();
134142

@@ -176,15 +184,20 @@ impl Binder {
176184
.flat_map(|expr| expr.used_columns().into_iter())
177185
})
178186
})
179-
.chain(mutation.required_columns.iter().copied())
180187
.collect::<HashSet<_>>();
181188

182189
let used_columns = used_columns
183190
.difference(&fields_bindings.iter().map(|column| column.index).collect())
184191
.copied()
185192
.collect::<HashSet<_>>();
186193

187-
let aggr_columns = used_columns
194+
struct AnyColumn {
195+
old: IndexType,
196+
new: IndexType,
197+
cast: Option<ScalarExpr>,
198+
}
199+
200+
let mut any_columns = used_columns
188201
.iter()
189202
.copied()
190203
.filter_map(|i| {
@@ -201,7 +214,7 @@ impl Binder {
201214

202215
let display_name = format!("any({})", binding.index);
203216
let old = binding.index;
204-
let mut aggr_func = ScalarExpr::AggregateFunction(AggregateFunction {
217+
let mut any_func: ScalarExpr = AggregateFunction {
205218
span: None,
206219
func_name: "any".to_string(),
207220
distinct: false,
@@ -210,14 +223,15 @@ impl Binder {
210223
span: None,
211224
column: binding.clone(),
212225
})],
213-
return_type: binding.data_type.clone(),
226+
return_type: Box::new(binding.data_type.wrap_nullable()),
214227
sort_descs: vec![],
215228
display_name: display_name.clone(),
216-
});
229+
}
230+
.into();
217231

218232
let mut rewriter =
219233
AggregateRewriter::new(&mut mutation.bind_context, self.metadata.clone());
220-
rewriter.visit(&mut aggr_func).unwrap();
234+
rewriter.visit(&mut any_func).unwrap();
221235

222236
let new = mutation
223237
.bind_context
@@ -226,10 +240,48 @@ impl Binder {
226240
.unwrap()
227241
.index;
228242

229-
Some((aggr_func, old, new))
243+
let (cast, new) = if !binding.data_type.is_nullable() {
244+
let ColumnBinding {
245+
column_name,
246+
data_type,
247+
..
248+
} = binding;
249+
250+
let column = ColumnBindingBuilder::new(
251+
column_name.clone(),
252+
new,
253+
data_type.clone(),
254+
Visibility::Visible,
255+
)
256+
.build();
257+
let column = ScalarExpr::BoundColumnRef(BoundColumnRef { span: None, column });
258+
let cast = column.unify_to_data_type(&data_type);
259+
260+
let index = self.metadata.write().add_derived_column(
261+
column_name,
262+
*data_type,
263+
Some(cast.clone()),
264+
);
265+
(Some(cast), index)
266+
} else {
267+
(None, new)
268+
};
269+
270+
Some(AnyColumn { old, new, cast })
230271
})
231272
.collect::<Vec<_>>();
232273

274+
let items = any_columns
275+
.iter_mut()
276+
.filter_map(|col| {
277+
col.cast.take().map(|scalar| ScalarItem {
278+
scalar,
279+
index: col.new,
280+
})
281+
})
282+
.collect();
283+
let eval_scalar = EvalScalar { items };
284+
233285
mutation.bind_context.aggregate_info.group_items = fields_bindings
234286
.into_iter()
235287
.chain(std::iter::once(row_id))
@@ -241,55 +293,52 @@ impl Binder {
241293

242294
for eval in &mut mutation.matched_evaluators {
243295
if let Some(expr) = &mut eval.condition {
244-
for (_, old, new) in &aggr_columns {
245-
expr.replace_column(*old, *new)?
296+
for col in &any_columns {
297+
expr.replace_column(col.old, col.new)?
246298
}
247299
}
248300

249301
if let Some(update) = &mut eval.update {
250302
for (_, expr) in update.iter_mut() {
251-
for (_, old, new) in &aggr_columns {
252-
expr.replace_column(*old, *new)?
303+
for col in &any_columns {
304+
expr.replace_column(col.old, col.new)?
253305
}
254306
}
255307
}
256308
}
257309

258310
for (_, column) in mutation.field_index_map.iter_mut() {
259-
if let Some((_, _, index)) = aggr_columns
260-
.iter()
261-
.find(|(_, i, _)| i.to_string() == *column)
262-
{
263-
*column = index.to_string()
311+
if let Some(col) = any_columns.iter().find(|c| c.old.to_string() == *column) {
312+
*column = col.new.to_string()
264313
};
265314
}
266315

267316
mutation.required_columns = Box::new(
268317
std::iter::once(mutation.row_id_index)
269-
.chain(aggr_columns.into_iter().map(|(_, _, i)| i))
318+
.chain(any_columns.into_iter().map(|c| c.new))
270319
.collect(),
271320
);
272321

273-
let aggr_expr = self.bind_aggregate(&mut mutation.bind_context, (**filter_expr).clone())?;
322+
let aggr_expr =
323+
self.bind_aggregate(&mut mutation.bind_context, s_expr.unary_child().clone())?;
274324

275-
let s_expr = SExpr::create_unary(
276-
Arc::new(RelOperator::Mutation(mutation)),
277-
Arc::new(aggr_expr),
278-
);
325+
let input = if eval_scalar.items.is_empty() {
326+
aggr_expr
327+
} else {
328+
aggr_expr.build_unary(Arc::new(eval_scalar.into()))
329+
};
279330

331+
let s_expr = Box::new(input.build_unary(Arc::new(mutation.into())));
280332
let Plan::DataMutation {
281333
schema, metadata, ..
282334
} = plan
283335
else {
284336
unreachable!()
285337
};
286-
287-
let plan = Plan::DataMutation {
288-
s_expr: Box::new(s_expr),
338+
Ok(Plan::DataMutation {
339+
s_expr,
289340
schema,
290341
metadata,
291-
};
292-
293-
Ok(plan)
342+
})
294343
}
295344
}

src/query/sql/src/planner/optimizer/ir/expr/s_expr.rs

+24-8
Original file line numberDiff line numberDiff line change
@@ -68,35 +68,51 @@ pub struct SExpr {
6868
impl SExpr {
6969
pub fn create(
7070
plan: Arc<RelOperator>,
71-
children: impl IntoIterator<Item = Arc<SExpr>>,
71+
children: impl Into<Vec<Arc<SExpr>>>,
7272
original_group: Option<IndexType>,
7373
rel_prop: Option<Arc<RelationalProperty>>,
7474
stat_info: Option<Arc<StatInfo>>,
7575
) -> Self {
7676
SExpr {
7777
plan,
78-
children: children.into_iter().collect(),
78+
children: children.into(),
7979
original_group,
8080
rel_prop: Arc::new(Mutex::new(rel_prop)),
8181
stat_info: Arc::new(Mutex::new(stat_info)),
8282
applied_rules: AppliedRules::default(),
8383
}
8484
}
8585

86-
pub fn create_unary(plan: Arc<RelOperator>, child: Arc<SExpr>) -> Self {
87-
Self::create(plan, vec![child], None, None, None)
86+
pub fn create_unary(plan: Arc<RelOperator>, child: impl Into<Arc<SExpr>>) -> Self {
87+
Self::create(plan, [child.into()], None, None, None)
8888
}
8989

9090
pub fn create_binary(
9191
plan: Arc<RelOperator>,
92-
left_child: Arc<SExpr>,
93-
right_child: Arc<SExpr>,
92+
left_child: impl Into<Arc<SExpr>>,
93+
right_child: impl Into<Arc<SExpr>>,
9494
) -> Self {
95-
Self::create(plan, vec![left_child, right_child], None, None, None)
95+
Self::create(
96+
plan,
97+
[left_child.into(), right_child.into()],
98+
None,
99+
None,
100+
None,
101+
)
96102
}
97103

98104
pub fn create_leaf(plan: Arc<RelOperator>) -> Self {
99-
Self::create(plan, vec![], None, None, None)
105+
Self::create(plan, [], None, None, None)
106+
}
107+
108+
pub fn build_unary(self, plan: Arc<RelOperator>) -> Self {
109+
debug_assert_eq!(plan.arity(), 1);
110+
Self::create(plan, [self.into()], None, None, None)
111+
}
112+
113+
pub fn ref_build_unary(self: &Arc<SExpr>, plan: Arc<RelOperator>) -> Self {
114+
debug_assert_eq!(plan.arity(), 1);
115+
Self::create(plan, [self.clone()], None, None, None)
100116
}
101117

102118
pub fn plan(&self) -> &RelOperator {

src/query/sql/src/planner/optimizer/optimizers/rule/agg_rules/rule_push_down_limit_aggregate.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -119,7 +119,7 @@ impl RulePushDownRankLimitAggregate {
119119
Arc::new(RelOperator::Aggregate(agg_limit)),
120120
Arc::new(agg.child(0)?.clone()),
121121
);
122-
let sort = SExpr::create_unary(Arc::new(RelOperator::Sort(sort)), agg.into());
122+
let sort = SExpr::create_unary(Arc::new(RelOperator::Sort(sort)), agg);
123123
let mut result = s_expr.replace_children(vec![Arc::new(sort)]);
124124

125125
result.set_applied_rule(&self.id);

src/query/sql/src/planner/optimizer/optimizers/rule/filter_rules/rule_push_down_filter_window_top_n.rs

+2-3
Original file line numberDiff line numberDiff line change
@@ -127,9 +127,8 @@ impl Rule for RulePushDownFilterWindowTopN {
127127
s_expr.plan.clone(),
128128
SExpr::create_unary(
129129
window_expr.plan.clone(),
130-
sort_expr.replace_plan(Arc::new(sort.into())).into(),
131-
)
132-
.into(),
130+
sort_expr.replace_plan(Arc::new(sort.into())),
131+
),
133132
);
134133
result.set_applied_rule(&self.id);
135134

tests/sqllogictests/suites/base/03_common/03_0035_update.test

+3
Original file line numberDiff line numberDiff line change
@@ -244,3 +244,6 @@ select * from t;
244244

245245
statement ok
246246
DROP DATABASE db1
247+
248+
statement ok
249+
use default;

tests/sqllogictests/suites/base/03_common/03_0048_nondeterministic_update.test

+17
Original file line numberDiff line numberDiff line change
@@ -69,5 +69,22 @@ include ../issues/issue_15278.test
6969

7070
include ./03_0035_update.test
7171

72+
statement ok
73+
create or replace table test_merge( col1 varchar, col2 varchar, col3 varchar);
74+
75+
statement ok
76+
insert into test_merge values(2,'abc',2),(3,'abc',3),(4,'abc',4);
77+
78+
statement ok
79+
with tbb("col1", "col2", "col3") as (values ('1', 'add', '11'), ('4', 'add', '44'))
80+
update test_merge tba set tba.col1 =tbb.col1, tba.col2 = 'update', tba.col3 = tbb.col3 from tbb where tba.col1 = tbb.col1;
81+
82+
query ITI
83+
select * from test_merge order by col1;
84+
----
85+
2 abc 2
86+
3 abc 3
87+
4 update 44
88+
7289
statement ok
7390
unset error_on_nondeterministic_update;

tests/sqllogictests/suites/base/issues/issue_15278.test

+1-1
Original file line numberDiff line numberDiff line change
@@ -38,4 +38,4 @@ statement ok
3838
drop table if exists test;
3939

4040
statement ok
41-
drop table if exists test_tmp;
41+
drop table if exists test_tmp;

0 commit comments

Comments
 (0)