diff --git a/src/register_units.jl b/src/register_units.jl index 1bbfd129..387bbfe7 100644 --- a/src/register_units.jl +++ b/src/register_units.jl @@ -1,18 +1,36 @@ -import .Units: UNIT_MAPPING, UNIT_SYMBOLS, UNIT_VALUES, _lazy_register_unit +import .Units: UNIT_MAPPING, UNIT_SYMBOLS, UNIT_VALUES import .SymbolicUnits: update_external_symbolic_unit_value # Update the unit collections const UNIT_UPDATE_LOCK = Threads.SpinLock() +function update_all_values_unlocked(name_symbol, unit) + push!(UNIT_SYMBOLS, name_symbol) + push!(UNIT_VALUES, unit) + push!(ALL_SYMBOLS, name_symbol) + push!(ALL_VALUES, unit) + i = lastindex(ALL_VALUES) + ALL_MAPPING[name_symbol] = i + UNIT_MAPPING[name_symbol] = i + update_external_symbolic_unit_value(name_symbol) +end + function update_all_values(name_symbol, unit) lock(UNIT_UPDATE_LOCK) do - push!(ALL_SYMBOLS, name_symbol) - push!(ALL_VALUES, unit) - i = lastindex(ALL_VALUES) - ALL_MAPPING[name_symbol] = i - UNIT_MAPPING[name_symbol] = i - update_external_symbolic_unit_value(name_symbol) + index = get(ALL_MAPPING, name_symbol, INDEX_TYPE(0)) + if iszero(index) + update_all_values_unlocked(name_symbol, unit) + elseif ALL_VALUES[index] != unit + error("Unit `$name_symbol` is already defined as `$(ALL_VALUES[index])`") + end + end +end + +function define_unit_binding(mod::Module, name::Symbol, unit) + if !isdefined(mod, name) + Core.eval(mod, Expr(:const, Expr(:(=), name, QuoteNode(unit)))) end + return unit end """ @@ -46,10 +64,11 @@ julia> x * y^2 |> us"W^2" |> sqrt |> uexpand """ macro register_unit(symbol, value) - return esc(_register_unit(symbol, value)) + declare_external_unit(__module__, symbol) + return esc(_register_unit(__module__, symbol, value)) end -function _register_unit(name::Symbol, value) +function _register_unit(mod::Module, name::Symbol, value) name_symbol = Meta.quot(name) index = get(ALL_MAPPING, name, INDEX_TYPE(0)) if !iszero(index) @@ -60,13 +79,10 @@ function _register_unit(name::Symbol, value) # unit.value != value && throw("Unit $name is already defined as $unit") error("Unit `$name` is already defined as `$unit`") end - reg_expr = _lazy_register_unit(name, value) - push!( - reg_expr.args, - quote - $update_all_values($name_symbol, $value) - nothing - end - ) - return reg_expr + return quote + local unit = $value + $define_unit_binding($(QuoteNode(mod)), $name_symbol, unit) + $update_all_values($name_symbol, unit) + nothing + end end diff --git a/src/symbolic_dimensions.jl b/src/symbolic_dimensions.jl index 6513c49e..89743d99 100644 --- a/src/symbolic_dimensions.jl +++ b/src/symbolic_dimensions.jl @@ -415,7 +415,21 @@ module SymbolicUnits import ..DEFAULT_SYMBOLIC_QUANTITY_OUTPUT_TYPE import ..DEFAULT_VALUE_TYPE import ..DEFAULT_DIM_BASE_TYPE + import ..INDEX_TYPE + import ..ensure_registered_external_unit + import ..external_quantity_binding + import ..external_unit_declaration import ..WriteOnceReadMany + import ..disambiguate_constant_symbol + + symbolic_unit_from_symbol(unit::Symbol) = constructorof(DEFAULT_SYMBOLIC_QUANTITY_TYPE)( + DEFAULT_VALUE_TYPE(1.0), + SymbolicDimensionsSingleton{DEFAULT_DIM_BASE_TYPE}(unit) + ) + symbolic_constant_from_symbol(unit::Symbol) = constructorof(DEFAULT_SYMBOLIC_QUANTITY_TYPE)( + DEFAULT_VALUE_TYPE(1.0), + SymbolicDimensionsSingleton{DEFAULT_DIM_BASE_TYPE}(disambiguate_constant_symbol(unit)) + ) # Lazily create unit symbols (since there are so many) module Constants @@ -462,11 +476,7 @@ module SymbolicUnits # Non-eval version of `update_symbolic_unit_values!` for registering units in # an external module. function update_external_symbolic_unit_value(unit) - unit = constructorof(DEFAULT_SYMBOLIC_QUANTITY_TYPE)( - DEFAULT_VALUE_TYPE(1.0), - SymbolicDimensionsSingleton{DEFAULT_DIM_BASE_TYPE}(unit) - ) - push!(SYMBOLIC_UNIT_VALUES, unit) + push!(SYMBOLIC_UNIT_VALUES, symbolic_unit_from_symbol(unit)) end """ @@ -495,39 +505,45 @@ module SymbolicUnits as_quantity(x::Number) = convert(DEFAULT_SYMBOLIC_QUANTITY_OUTPUT_TYPE, x) as_quantity(x) = error("Unexpected type evaluated: $(typeof(x))") - @unstable function map_to_scope(ex::Expr) + @unstable map_to_scope(ex::Expr) = map_to_scope(@__MODULE__, ex) + @unstable function map_to_scope(mod::Module, ex::Expr) if !(ex.head == :call) && !(ex.head == :. && ex.args[1] == :Constants) throw(ArgumentError("Unexpected expression: $ex. Only `:call` and `:.` (for `SymbolicConstants`) are expected.")) end if ex.head == :call - ex.args[2:end] = map(map_to_scope, ex.args[2:end]) + ex.args[2:end] = map(arg -> map_to_scope(mod, arg), ex.args[2:end]) return ex else # if ex.head == :. && ex.args[1] == :Constants @assert ex.args[2] isa QuoteNode - return lookup_constant(ex.args[2].value) + return Expr(:call, GlobalRef(@__MODULE__, :lookup_constant), QuoteNode(ex.args[2].value)) end end - function map_to_scope(sym::Symbol) - if sym in UNIT_SYMBOLS - # return at end - elseif sym in CONSTANT_SYMBOLS + map_to_scope(sym::Symbol) = map_to_scope(@__MODULE__, sym) + function map_to_scope(mod::Module, sym::Symbol) + has_registered_binding = sym in UNIT_SYMBOLS + has_external_binding = !(mod === @__MODULE__) && ( + external_quantity_binding(mod, sym) || external_unit_declaration(mod, sym) + ) + + if !has_registered_binding && sym in CONSTANT_SYMBOLS throw(ArgumentError("Symbol $sym found in `Constants` but not `Units`. Please use `us\"Constants.$sym\"` instead.")) - else + elseif !has_registered_binding && !has_external_binding throw(ArgumentError("Symbol $sym not found in `Units` or `Constants`.")) + elseif has_external_binding + return Expr(:call, GlobalRef(@__MODULE__, :lookup_external_unit), QuoteNode(mod), QuoteNode(sym)) end - return lookup_unit(sym) - end - function map_to_scope(ex) - return ex - end - function lookup_unit(ex::Symbol) - i = findfirst(==(ex), UNIT_SYMBOLS)::Int - return as_quantity(SYMBOLIC_UNIT_VALUES[i]) + + return Expr(:call, GlobalRef(@__MODULE__, :lookup_unit), QuoteNode(sym)) end - function lookup_constant(ex::Symbol) - i = findfirst(==(ex), CONSTANT_SYMBOLS)::Int - return as_quantity(SYMBOLIC_CONSTANT_VALUES[i]) + map_to_scope(ex) = ex + map_to_scope(::Module, ex) = ex + + lookup_unit(ex::Symbol) = as_quantity(symbolic_unit_from_symbol(ex)) + function lookup_external_unit(mod::Module, sym::Symbol) + ensure_registered_external_unit(sym, getfield(mod, sym)) + return lookup_unit(sym) end + lookup_constant(ex::Symbol) = as_quantity(symbolic_constant_from_symbol(ex)) end import .SymbolicUnits: as_quantity, sym_uparse, SymbolicConstants, map_to_scope @@ -548,7 +564,7 @@ module. So, for example, `us"Constants.c^2 * Hz^2"` would evaluate to namespace collisions, a few physical constants are automatically converted. """ macro us_str(s) - ex = map_to_scope(Meta.parse(s)) + ex = map_to_scope(__module__, Meta.parse(s)) ex = :($as_quantity($ex)) return esc(ex) end diff --git a/src/units.jl b/src/units.jl index 2ae7608e..8f7018e5 100644 --- a/src/units.jl +++ b/src/units.jl @@ -302,6 +302,8 @@ end # Do not wish to define physical constants, as the number of symbols might lead to ambiguity. # The user should define these instead. +const BUILTIN_UNIT_SYMBOLS = Tuple(UNIT_SYMBOLS._raw_data) + # Update `UNIT_MAPPING` with all internally defined unit symbols. const UNIT_MAPPING = WriteOnceReadMany(Dict(s => i for (i, s) in enumerate(UNIT_SYMBOLS))) diff --git a/src/uparse.jl b/src/uparse.jl index 27f0ce60..53a4ec64 100644 --- a/src/uparse.jl +++ b/src/uparse.jl @@ -1,3 +1,39 @@ +const EXTERNAL_UNIT_DECLARATION_LOCK = Threads.SpinLock() +const EXTERNAL_UNIT_DECLARATIONS = IdDict{Module,Set{Symbol}}() + +function declare_external_unit(mod::Module, name::Symbol) + lock(EXTERNAL_UNIT_DECLARATION_LOCK) do + push!(get!(() -> Set{Symbol}(), EXTERNAL_UNIT_DECLARATIONS, mod), name) + end + return nothing +end + +function external_unit_declaration(mod::Module, name::Symbol) + lock(EXTERNAL_UNIT_DECLARATION_LOCK) do + declarations = get(EXTERNAL_UNIT_DECLARATIONS, mod, nothing) + return declarations !== nothing && name in declarations + end +end + +function external_quantity_binding(mod::Module, sym::Symbol) + return isdefined(mod, sym) && getfield(mod, sym) isa UnionAbstractQuantity +end + +function ensure_registered_external_unit(sym::Symbol, unit::UnionAbstractQuantity) + lock(UNIT_UPDATE_LOCK) do + if iszero(get(UNIT_MAPPING, sym, 0)) + update_all_values_unlocked(sym, unit) + end + end + return nothing +end + +function lookup_registered_unit(sym::Symbol) + i = get(UNIT_MAPPING, sym, 0) + iszero(i) && throw(ArgumentError("Symbol $sym not found in `Units`.")) + return ALL_VALUES[i] +end + module UnitsParse using DispatchDoctor: @unstable @@ -6,7 +42,11 @@ import ..constructorof import ..DEFAULT_QUANTITY_TYPE import ..DEFAULT_DIM_TYPE import ..DEFAULT_VALUE_TYPE -import ..Units: UNIT_SYMBOLS, UNIT_VALUES +import ..external_quantity_binding +import ..external_unit_declaration +import ..ensure_registered_external_unit +import ..lookup_registered_unit +import ..Units: UNIT_SYMBOLS import ..Constants: CONSTANT_SYMBOLS, CONSTANT_VALUES import ..Constants @@ -59,38 +99,50 @@ the quantity corresponding to the speed of light multiplied by Hertz, squared. """ macro u_str(s) - ex = map_to_scope(Meta.parse(s)) + ex = map_to_scope(__module__, Meta.parse(s)) ex = :($as_quantity($ex)) return esc(ex) end -@unstable function map_to_scope(ex::Expr) +@unstable map_to_scope(ex::Expr) = map_to_scope(@__MODULE__, ex) +@unstable function map_to_scope(mod::Module, ex::Expr) if !(ex.head == :call) && !(ex.head == :. && ex.args[1] == :Constants) throw(ArgumentError("Unexpected expression: $ex. Only `:call` and `:.` (for `Constants`) are expected.")) end if ex.head == :call - ex.args[2:end] = map(map_to_scope, ex.args[2:end]) + ex.args[2:end] = map(arg -> map_to_scope(mod, arg), ex.args[2:end]) return ex else # if ex.head == :. && ex.args[1] == :Constants @assert ex.args[2] isa QuoteNode - return lookup_constant(ex.args[2].value) + return Expr(:call, GlobalRef(@__MODULE__, :lookup_constant), QuoteNode(ex.args[2].value)) end end -function map_to_scope(sym::Symbol) - if sym in UNIT_SYMBOLS - return lookup_unit(sym) - elseif sym in CONSTANT_SYMBOLS +map_to_scope(sym::Symbol) = map_to_scope(@__MODULE__, sym) +function map_to_scope(mod::Module, sym::Symbol) + has_registered_binding = sym in UNIT_SYMBOLS + has_external_binding = !(mod === @__MODULE__) && ( + external_quantity_binding(mod, sym) || external_unit_declaration(mod, sym) + ) + + if !has_registered_binding && sym in CONSTANT_SYMBOLS throw(ArgumentError("Symbol $sym found in `Constants` but not `Units`. Please use `u\"Constants.$sym\"` instead.")) - else + elseif !has_registered_binding && !has_external_binding throw(ArgumentError("Symbol $sym not found in `Units` or `Constants`.")) + elseif has_external_binding + return Expr(:call, GlobalRef(@__MODULE__, :lookup_external_unit), QuoteNode(mod), QuoteNode(sym)) end + + return Expr(:call, GlobalRef(@__MODULE__, :lookup_unit), QuoteNode(sym)) end function map_to_scope(ex) return ex end -function lookup_unit(ex::Symbol) - i = findfirst(==(ex), UNIT_SYMBOLS)::Int - return UNIT_VALUES[i] +map_to_scope(::Module, ex) = ex + +@unstable lookup_unit(ex::Symbol) = lookup_registered_unit(ex) +@unstable function lookup_external_unit(mod::Module, sym::Symbol) + ensure_registered_external_unit(sym, getfield(mod, sym)) + return lookup_registered_unit(sym) end function lookup_constant(ex::Symbol) i = findfirst(==(ex), CONSTANT_SYMBOLS)::Int diff --git a/test/precompile_test/ExternalUnitRegistration.jl b/test/precompile_test/ExternalUnitRegistration.jl deleted file mode 100644 index 6395b9d2..00000000 --- a/test/precompile_test/ExternalUnitRegistration.jl +++ /dev/null @@ -1,21 +0,0 @@ -module ExternalUnitRegistration - -using DynamicQuantities: @register_unit, @u_str, @us_str -using DynamicQuantities: ALL_MAPPING, ALL_SYMBOLS, DEFAULT_QUANTITY_TYPE -using DynamicQuantities: DEFAULT_SYMBOLIC_QUANTITY_OUTPUT_TYPE, UNIT_SYMBOLS, UNIT_MAPPING -using Test - -@register_unit MyWb u"m^2*kg*s^-2*A^-1" - -@testset "Register Unit Inside a Module" begin - for collection in (UNIT_SYMBOLS, ALL_SYMBOLS, keys(ALL_MAPPING._raw_data), keys(UNIT_MAPPING._raw_data)) - @test :MyWb ∈ collection - end - - w = u"MyWb" - ws = us"MyWb" - @test w isa DEFAULT_QUANTITY_TYPE - @test ws isa DEFAULT_SYMBOLIC_QUANTITY_OUTPUT_TYPE -end - -end diff --git a/test/precompile_test/ExternalUnitRegistration/src/ExternalUnitRegistration.jl b/test/precompile_test/ExternalUnitRegistration/src/ExternalUnitRegistration.jl new file mode 100644 index 00000000..721f6305 --- /dev/null +++ b/test/precompile_test/ExternalUnitRegistration/src/ExternalUnitRegistration.jl @@ -0,0 +1,28 @@ +module ExternalUnitRegistration + +using DynamicQuantities: @register_unit, @u_str, @us_str +using DynamicQuantities: DEFAULT_QUANTITY_TYPE, DEFAULT_SYMBOLIC_QUANTITY_OUTPUT_TYPE + +@register_unit MyWb u"m^2*kg*s^-2*A^-1" + +function __init__() + @register_unit MyInitWb u"m^2*kg*s^-2*A^-1" +end + +const MYWB_EXPANDED = u"MyWb" + +expanded_mywb() = 1u"MyWb" +symbolic_mywb() = 1us"MyWb" +expanded_mywb_from_helper() = one(expanded_mywb()) * u"MyWb" +symbolic_mywb_from_helper() = one(symbolic_mywb()) * us"MyWb" +expanded_constant_mywb() = MYWB_EXPANDED +init_expanded_mywb() = 1u"MyInitWb" +init_symbolic_mywb() = 1us"MyInitWb" + +export MyWb +export MYWB_EXPANDED +export expanded_constant_mywb, expanded_mywb, expanded_mywb_from_helper +export symbolic_mywb, symbolic_mywb_from_helper +export init_expanded_mywb, init_symbolic_mywb + +end diff --git a/test/unittests.jl b/test/unittests.jl index 76553a6b..3e518517 100644 --- a/test/unittests.jl +++ b/test/unittests.jl @@ -710,6 +710,13 @@ end @test_throws "Symbol c found in `Constants` but not `Units`" sym_uparse("c") @test_throws "Unexpected expression" sym_uparse("import ..Units") @test_throws "Unexpected expression" sym_uparse("(m, m)") + + @eval module SymbolicUnitShadowingTest + using DynamicQuantities + const c = 1u"m" + end + @test_throws "Symbol c found in `Constants` but not `Units`" Core.eval(SymbolicUnitShadowingTest, :(us"c")) + @test Core.eval(SymbolicUnitShadowingTest, :(us"Constants.c")) == us"Constants.c" end @testset "Constants" begin @@ -975,6 +982,10 @@ end sym5 = dimension(us"km/s") @test_throws "my_special_symbol is not available as a symbol" sym5.my_special_symbol + # Exercise the no-module SymbolicUnits wrappers directly. + @test DynamicQuantities.SymbolicUnits.lookup_unit(:m) == us"m" + @test DynamicQuantities.SymbolicUnits.lookup_constant(:c) == us"Constants.c" + # Test deprecated method q = 1.5us"km/s" @test expand_units(q) == uexpand(q) @@ -2282,9 +2293,6 @@ end @test_throws DimensionError x^y end -# `@testset` rewrites the test block with a `let...end`, resulting in an invalid -# local `const` (ref: src/units.jl:26). To avoid it, register units outside the -# test block. map_count_before_registering = length(UNIT_MAPPING) all_map_count_before_registering = length(ALL_MAPPING) @@ -2301,16 +2309,23 @@ if :MySV2 ∉ UNIT_SYMBOLS @eval @register_unit MySV2 us"km/h" end -@test_throws "Unit `m` is already defined as `1.0 m`" esc(_register_unit(:m, u"s")) +@test_throws "Unit `m` is already defined as `1.0 m`" esc(_register_unit(@__MODULE__, :m, u"s")) # Constants as well: -@test_throws "Unit `Ryd` is already defined" esc(_register_unit(:Ryd, u"Constants.Ryd")) +@test_throws "Unit `Ryd` is already defined" esc(_register_unit(@__MODULE__, :Ryd, u"Constants.Ryd")) @testset "Register Unit" begin MyV = u"MyV" MySV = u"MySV" MySV2 = u"MySV2" + @test uparse("MyV") == u"V" + @test uparse("MySV") == u"V" + @test uparse("MySV2") == u"km/h" + @test sym_uparse("MyV") == us"MyV" + @test sym_uparse("MySV") == us"MySV" + @test sym_uparse("MySV2") == us"MySV2" + @test MyV === u"V" @test MyV == us"V" @test MySV == us"V" @@ -2334,11 +2349,43 @@ end push!(LOAD_PATH, joinpath(@__DIR__, "precompile_test")) -using ExternalUnitRegistration: MyWb +using ExternalUnitRegistration: MYWB_EXPANDED, MyWb +using ExternalUnitRegistration: expanded_constant_mywb, expanded_mywb, expanded_mywb_from_helper +using ExternalUnitRegistration: symbolic_mywb, symbolic_mywb_from_helper +using ExternalUnitRegistration: init_expanded_mywb, init_symbolic_mywb @testset "Type of External Unit" begin - @test MyWb isa DEFAULT_QUANTITY_TYPE - @test MyWb/u"m^2*kg*s^-2*A^-1" == 1.0 + @test MYWB_EXPANDED isa DEFAULT_QUANTITY_TYPE + @test MYWB_EXPANDED / u"m^2*kg*s^-2*A^-1" == 1.0 + @test u"MyWb" == MYWB_EXPANDED + @test uexpand(us"MyWb") == MYWB_EXPANDED + @test string(us"MyWb") == "1.0 MyWb" + @test MyWb == MYWB_EXPANDED + @test expanded_constant_mywb() == MYWB_EXPANDED + @test expanded_mywb() == MYWB_EXPANDED + @test expanded_mywb_from_helper() == expanded_mywb() + @test uexpand(symbolic_mywb()) == MYWB_EXPANDED + @test symbolic_mywb_from_helper() == symbolic_mywb() + @test string(symbolic_mywb()) == "1.0 MyWb" + @test init_expanded_mywb() == MYWB_EXPANDED + @test uexpand(init_symbolic_mywb()) == MYWB_EXPANDED + @test string(init_symbolic_mywb()) == "1.0 MyInitWb" end -pop!(LOAD_PATH) +@testset "Concurrent first-use registration" begin + if Threads.nthreads() > 1 + @eval module SymbolicUnitConcurrentRegistrationTest + using DynamicQuantities + const ConcurrentFooUnitForLazyRegistration = 1u"m" + parse_concurrent_symbol() = us"ConcurrentFooUnitForLazyRegistration" + end + results = Vector{Any}(undef, Threads.nthreads()) + Threads.@threads for i in eachindex(results) + results[i] = SymbolicUnitConcurrentRegistrationTest.parse_concurrent_symbol() + end + + @test all(x -> uexpand(x) == 1u"m", results) + end +end + +pop!(LOAD_PATH)