diff --git a/src/Parse.jl b/src/Parse.jl index 10d121db..066d5945 100644 --- a/src/Parse.jl +++ b/src/Parse.jl @@ -200,6 +200,21 @@ end ) end +_replace_imaginary_unit_symbol(ex) = ex +@unstable _replace_imaginary_unit_symbol(ex::Symbol) = ex === :im ? im : ex +function _replace_imaginary_unit_symbol(ex::Expr) + return Expr(ex.head, map(_replace_imaginary_unit_symbol, ex.args)...) +end + +@unstable function _normalize_expression_for_parse( + ex, variable_names::Union{AbstractVector{<:AbstractString},Nothing} +) + if variable_names !== nothing && ("im" in variable_names) + return ex + end + return _replace_imaginary_unit_symbol(ex) +end + """Parse an expression Julia `Expr` object.""" @unstable function parse_expression( ex; @@ -237,11 +252,14 @@ end operators end + ex = _normalize_expression_for_parse(ex, variable_names) tree = _parse_expression(ex, operators, variable_names, N, E, evaluate_on; kws...) return constructorof(E)(tree; operators, variable_names, kws...) end end +@unstable parse_expression(ex::String; kws...) = parse_expression(Meta.parse(ex); kws...) + """An empty module for evaluation without collisions.""" module EmptyModule end diff --git a/test/test_parse.jl b/test/test_parse.jl index 50121888..7e9c8f99 100644 --- a/test/test_parse.jl +++ b/test/test_parse.jl @@ -44,6 +44,51 @@ end end +@testitem "String parse treats Julia imaginary unit `im` as a constant" begin + using DynamicExpressions + using Test + + operators = OperatorEnum(; + binary_operators=[+, -, *, /], unary_operators=[], define_helper_functions=false + ) + + ex = parse_expression("0.1im + x"; operators, variable_names=["x"]) + @test typeof(ex) <: Expression{ComplexF64} + + function count_vars(n) + if n.degree == 0 + return n.constant ? 0 : 1 + elseif n.degree == 1 + return count_vars(n.l) + else + return count_vars(n.l) + count_vars(n.r) + end + end + + @test count_vars(ex.tree) == 1 + + # Check evaluation + eltype + Xr = reshape([1.0], 1, :) + yr = ex(Xr) + @test eltype(yr) == ComplexF64 + @test yr[1] ≈ 1.0 + 0.1im + + Xc = reshape([1.0 + 2.0im], 1, :) + yc = ex(Xc) + @test eltype(yc) == ComplexF64 + @test yc[1] ≈ 1.0 + 2.1im + + # If `"im"` is in `variable_names`, it should be treated as a variable + ex2 = parse_expression("im + x2"; operators, variable_names=["im", "x2"]) + @test typeof(ex2) <: Expression + @test count_vars(ex2.tree) == 2 + + X2 = reshape(Float32[2, 3], 2, :) + y2 = ex2(X2) + @test eltype(y2) == Float32 + @test y2[1] == 5.0f0 +end + @testitem "Can also parse just a float" begin using DynamicExpressions operators = OperatorEnum() # Tests empty operators