diff --git a/src/lazymacro.jl b/src/lazymacro.jl index 521c0d66..c522ecd4 100644 --- a/src/lazymacro.jl +++ b/src/lazymacro.jl @@ -15,10 +15,17 @@ Broadcast.materialize(x::LazyCast) = x.value is_call(ex) = isexpr(ex, :call) && !is_dotcall(ex) -is_dotcall(ex) = - (isexpr(ex, :.) && isexpr(ex.args[2], :tuple)) || - (isexpr(ex, :call) && ex.args[1] isa Symbol && startswith(String(ex.args[1]), ".")) -# e.g., `f.(x, y, z)` or `x .+ y .+ z` +is_dotcall(ex) = is_dotcall_fn(ex) || is_dotcall_op(ex) + +is_dotcall_fn(ex) = (isexpr(ex, :.) && isexpr(ex.args[2], :tuple)) +# e.g., `f.(x, y, z)` + +function is_dotcall_op(ex) + isexpr(ex, :call) && !isempty(ex.args) || return false + op = ex.args[1] + return op isa Symbol && Base.isoperator(op) && startswith(string(op), ".") +end +# e.g., `x .+ y .+ z` lazy_expr(x) = x function lazy_expr(ex::Expr) @@ -40,12 +47,12 @@ end bc_expr_impl(x) = x function bc_expr_impl(ex::Expr) # walk down chain of dot calls - if ex.head == :. && ex.args[2].head === :tuple + if is_dotcall_fn(ex) @assert length(ex.args) == 2 # argument is always expressed as a tuple f = ex.args[1] # function name args = ex.args[2].args return Expr(ex.head, lazy_expr(f), Expr(:tuple, bc_expr_impl.(args)...)) - elseif ex.head == :call && startswith(String(ex.args[1]), ".") + elseif is_dotcall_op(ex) f = ex.args[1] # function name (e.g., `.+`) args = ex.args[2:end] return Expr(ex.head, lazy_expr(f), bc_expr_impl.(args)...) @@ -64,6 +71,10 @@ app_expr_impl(x) = x function app_expr_impl(ex::Expr) # walk down chain of calls and lazy-ify them if is_call(ex) + if isexpr(ex.args[1], :$) + # eagerly evaluate the call + return Expr(:call, ex.args[1].args[1], ex.args[2:end]...) + end return :($applied($(app_expr_impl.(ex.args)...))) else return lazy_expr(ex) diff --git a/test/lazymultests.jl b/test/lazymultests.jl index b3851f92..981159e6 100644 --- a/test/lazymultests.jl +++ b/test/lazymultests.jl @@ -162,7 +162,7 @@ LinearAlgebra.factorize(A::MyLazyArray) = factorize(A.data) @test apply(*,A,x) isa ApplyVector @test apply(*,A,Array(x)) isa ApplyVector @test apply(*,Array(A),x) isa ApplyVector - @test apply(*,A,x) ≈ apply(*,Array(A),x) ≈ apply(*,A,Array(x)) ≈ Array(A)*Array(x) + @test apply(*,A,x) ≈ apply(*,Array(A),x) ≈ apply(*,A,Array(x)) ≈ Array(A)*Array(x) @test apply(*,A,B) isa ApplyMatrix @test apply(*,A,Array(B)) isa ApplyMatrix diff --git a/test/macrotests.jl b/test/macrotests.jl index 3c971405..b462087e 100644 --- a/test/macrotests.jl +++ b/test/macrotests.jl @@ -97,4 +97,23 @@ end @test bc.args[1].args isa Tuple{Applied, Int} end +@testset "@~ and \$" begin + A = ones(1, 1) + x = [1] + + # Use `$` to evaluate a sub-expression eagerly + bc = @~ A .+ $Ref(x) + @test bc isa Broadcasted + @test bc.args[1] === A + @test bc.args[2] isa Ref # not Applied + @test bc.args[2][] === x + + # Use `$$` when combined with `@.` + bc = @~ @. A + $$Ref(x) + @test bc isa Broadcasted + @test bc.args[1] === A + @test bc.args[2] isa Ref # not Applied + @test bc.args[2][] === x +end + end # module