Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 12 additions & 1 deletion marimo/_ast/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -356,7 +356,18 @@ def compile_cell(
expr, filename, mode="eval", dont_inherit=True, flags=flags
)

nonlocals = {name for name in v.defs if not is_local(name)}
# Imports are exempt from the underscore "cell-local" rule: their
# name comes from the package (the user can't control it), and we
# need them in the graph so consumers in other cells can resolve.
nonlocals = {
name
for name in v.defs
if not is_local(name)
or (
name in v.variable_data
and any(d.kind == "import" for d in v.variable_data[name])
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2: Import exemption too wide. It also exports user aliases like import x as _y as public defs. Keep alias-based _ imports cell-local.

(Based on your team's feedback about preserving backward compatibility for existing public behavior.)

View Feedback

Prompt for AI agents
Check if this issue is valid — if so, understand the root cause and fix it. At marimo/_ast/compiler.py, line 368:

<comment>Import exemption too wide. It also exports user aliases like `import x as _y` as public defs. Keep alias-based `_` imports cell-local.

(Based on your team's feedback about preserving backward compatibility for existing public behavior.) </comment>

<file context>
@@ -356,7 +356,18 @@ def compile_cell(
+        if not is_local(name)
+        or (
+            name in v.variable_data
+            and any(d.kind == "import" for d in v.variable_data[name])
+        )
+    }
</file context>

)
}
temporaries = v.defs - nonlocals
variable_data = {
name: v.variable_data[name]
Expand Down
29 changes: 18 additions & 11 deletions marimo/_ast/visitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -270,7 +270,9 @@ def _get_alias_name(
) -> str:
"""Get the string name of an imported alias.

Mangles the "as" name if it's a local variable.
Imports are never mangled. The user can't control upstream
package names, and `import x as _y` already gets cell-local
semantics via `is_local` -> `temporaries` in `compile_cell`.

NB: We disallow `import *` because Python only allows
star imports at module-level, but we store cells as functions.
Expand All @@ -280,9 +282,6 @@ def _get_alias_name(
# import [a.b.c] - we define a
# from foo import [a] - we define a
# from foo import [*] - we don't define anything
#
# Note:
# Don't mangle - user has no control over package name
basename = node.name.split(".")[0]
if basename == "*":
# Use the ImportFrom node's line number for consistency
Expand All @@ -299,9 +298,7 @@ def _get_alias_name(
"is not allowed in marimo."
)
return basename
else:
node.asname = self._if_local_then_mangle(node.asname)
return node.asname
return node.asname

def _is_defined(self, identifier: str) -> bool:
"""Check if `identifier` is defined in any block."""
Expand Down Expand Up @@ -967,6 +964,14 @@ def visit_Name(self, node: ast.Name) -> ast.Name:
):
self._add_ref(node, node.id, deleted=True)
elif self.is_local(node.id):
# An unresolved underscore-prefixed Load is intentionally
# *not* added to refs. We could promote it so the dataflow
# wires it up to e.g. `from foo import _bar` in another
# cell, but that would expand the cross-cell reactive
# surface to every undefined underscore name. For now we
# accept the trade-off: cross-cell underscore imports work
# at runtime via the shared globals dict (the importer must
# have run), but the reactive graph won't track the edge.
mangled_name = self._if_local_then_mangle(
node.id, ignore_scope=True
)
Expand All @@ -978,10 +983,12 @@ def visit_Name(self, node: ast.Name) -> ast.Name:
# doesn't define it yet — this handles recursive calls to
# underscore-prefixed functions, where the function name
# isn't registered in the top-level block until after its
# body is visited.
if (
block.is_defined(mangled_name)
or len(self.block_stack) > 1
# body is visited. Skip mangling when the *unmangled*
# name is already defined at top level (e.g. by an
# underscore-prefixed import, which is never mangled).
if block.is_defined(mangled_name) or (
len(self.block_stack) > 1
and not block.is_defined(node.id)
):
node.id = mangled_name
elif block.is_defined(node.id):
Expand Down
17 changes: 14 additions & 3 deletions marimo/_save/hash.py
Original file line number Diff line number Diff line change
Expand Up @@ -690,17 +690,28 @@ def serialize_and_dequeue_content_refs(
imports = get_imports(scope)
for local_ref in sorted(refs):
ref = if_local_then_mangle(local_ref, self.cell_id)
if ref in imports:
# Imports are never mangled (see `_get_alias_name`), so the
# `_private`-style import will appear in `imports` under its
# raw name while non-import locals are still keyed by their
# mangled form. Accept either.
import_key: Name | None = (
ref
if ref in imports
else local_ref
if local_ref in imports
else None
)
if import_key is not None:
# TODO: There may be a way to tie this in with module watching.
# e.g. module watcher could mutate the version number based
# last updated timestamp.
version = ""
module = None
if self.pin_modules:
module = sys.modules[imports[ref].module]
module = sys.modules[imports[import_key].module]
version = getattr(module, "__version__", "")
if not version:
module = sys.modules[imports[ref].namespace]
module = sys.modules[imports[import_key].namespace]
version = getattr(module, "__version__", "")

content_serialization[ref] = type_sign(
Expand Down
31 changes: 31 additions & 0 deletions tests/_ast/test_visitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -985,6 +985,37 @@ def test_relative_from_import() -> None:
]


def test_underscore_imports_never_mangled() -> None:
# Imports must never be mangled — the user can't control upstream
# package names. Covers `from foo import _bar`, multi-name imports,
# plain `import _foo`, and `import x as _y`. Also checks that
# underscore-imported names referenced from a nested scope
# (decorator) stay unmangled.
cases = [
("from foo import _bar", {"_bar"}),
("from foo import _bar, _baz, _qux", {"_bar", "_baz", "_qux"}),
("import _foo", {"_foo"}),
("import marimo as _mo", {"_mo"}),
("from a.b import _c as _d", {"_d"}),
(
"import marimo as _private\n"
"@_private.cache\n"
"def f(x):\n"
" return _private.md('x')\n",
{"_private", "f"},
),
]
for source, expected_defs in cases:
v = visitor.ScopedVisitor(mangle_prefix="cell_test")
mod = ast.parse(source)
v.visit(mod)
assert v.defs == expected_defs, source
unparsed = ast.unparse(mod)
assert "_cell_test_" not in unparsed, (
f"import was mangled in {source!r}: {unparsed!r}"
)


def test_from_import_star() -> None:
expr = "from a.b.c import *"
v = visitor.ScopedVisitor()
Expand Down
49 changes: 46 additions & 3 deletions tests/_runtime/test_runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -877,7 +877,9 @@ async def test_delete_nonlocal_incremental_ref_raises_name_error(
async def test_import_module_as_local_var(
self, any_kernel: Kernel
) -> None:
# Tests that imported names are mangled but still usable
# Imports are never mangled (see `_get_alias_name`) and are
# exposed as graph defs even when underscore-prefixed, so they
# can be referenced across cells.
k = any_kernel
await k.run(
[
Expand All @@ -887,9 +889,50 @@ async def test_import_module_as_local_var(
),
]
)
# _sys mangled, should not be in globals
assert "_sys" not in k.globals
assert k.globals["msize"] == sys.maxsize
assert "_sys" in k.graph.cells["0"].defs

async def test_underscore_prefixed_import_in_cell(
self, any_kernel: Kernel
) -> None:
# An underscore-prefixed `from x import _y` in a single cell
# must resolve when used in the same cell.
k = any_kernel
await k.run(
[
ExecuteCellCommand(
cell_id="0",
code=(
"from marimo import _output\nmsg = _output.md.md('hi')"
),
),
]
)
assert not k.errors, k.errors
assert "hi" in k.globals["msg"].text

async def test_underscore_prefixed_import_across_cells(
self, k: Kernel
) -> None:
# Cross-cell: one cell does `from x import _y`, another uses
# `_y`. The consumer resolves at runtime via shared globals
# (the importer ran first). No reactive edge is created — the
# consumer is not in the importer's children — so editing the
# importer won't automatically re-run the consumer. We accept
# that trade-off rather than promoting every undefined
# underscore Load to a cross-cell ref.
importer = ExecuteCellCommand(
cell_id="0",
code="from marimo import _output",
)
consumer = ExecuteCellCommand(
cell_id="1",
code="msg = _output.md.md('cross')",
)
await k.run([importer, consumer])
assert not k.errors, k.errors
assert "cross" in k.globals["msg"].text
assert "1" not in k.graph.children.get("0", set())

async def test_cell_transitioned_to_error_is_not_stale(
self, lazy_kernel: Kernel
Expand Down
56 changes: 56 additions & 0 deletions tests/_save/test_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -1111,6 +1111,62 @@ def foo():
assert k.stdout.messages.count("ran") == 1
assert k.globals["foo"].hits == 3

async def test_lru_cache_underscore_alias(
self, k: Kernel, exec_req: ExecReqProvider
) -> None:
# Regression: `import marimo as _private` (underscore alias) used to
# blow up cache hashing because mangled scope keys (`_cell_<id>__private`)
# don't match the unmangled-ref lookup used during the cache attempt.
await k.run(
[
exec_req.get(
"""
import marimo as _private

@_private.lru_cache(maxsize=128)
def f(x):
return x * 2

a = f(1)
b = f(1)
c = f(2)
"""
),
]
)

assert not k.stderr.messages, k.stderr
assert k.globals["a"] == 2
assert k.globals["b"] == 2
assert k.globals["c"] == 4
# Second `f(1)` should hit the cache.
assert k.globals["f"].hits == 1

async def test_cache_underscore_alias(
self, k: Kernel, exec_req: ExecReqProvider
) -> None:
await k.run(
[
exec_req.get(
"""
import marimo as _private

@_private.cache
def f(x):
return x * 3

a = f(1)
b = f(1)
"""
),
]
)

assert not k.stderr.messages, k.stderr
assert k.globals["a"] == 3
assert k.globals["b"] == 3
assert k.globals["f"].hits == 1

async def test_persistent_cache(
self, k: Kernel, exec_req: ExecReqProvider
) -> None:
Expand Down
Loading