From bfc014ac294b5df468a534b942b6de2c0a723ea5 Mon Sep 17 00:00:00 2001
From: JMPRIEST
Date: Thu, 27 Apr 2023 17:04:35 +0800
Subject: [PATCH 01/20] :construction: groundwork for the CREATE & DROP
 NAMESPACE syntax and the backing pg schemas is mostly done; the edb schema
 update logic is still to be added
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 edb/edgeql-parser/src/keywords.rs |   1 +
 edb/edgeql/ast.py                 |  14 ++++
 edb/edgeql/codegen.py             |   6 ++
 edb/edgeql/parser/grammar/ddl.py  |  65 ++++++++++++++++++
 edb/edgeql/qltypes.py             |   1 +
 edb/lib/sys.edgeql                |   2 +
 edb/pgsql/dbops/__init__.py       |   1 +
 edb/pgsql/dbops/namespace.py      |  56 +++++++++++++++
 edb/pgsql/delta.py                | 109 ++++++++++++++++++++++++++++++
 edb/pgsql/metaschema.py           | 101 +++++++++++++++++++++++++++
 edb/schema/namespace.py           |  78 +++++++++++++++++++++
 edb/server/bootstrap.py           |  80 ++++++++++++++++++++++
 edb/server/compiler/compiler.py   |   5 ++
 edb/server/compiler/dbstate.py    |   5 ++
 edb/server/pgcluster.py           |   2 +-
 edb/server/protocol/execute.pyx   |   2 +
 edb/server/server.py              |  26 ++++++-
 17 files changed, 552 insertions(+), 2 deletions(-)
 create mode 100644 edb/pgsql/dbops/namespace.py
 create mode 100644 edb/schema/namespace.py

diff --git a/edb/edgeql-parser/src/keywords.rs b/edb/edgeql-parser/src/keywords.rs
index 3d55fd69955..732c08e9db2 100644
--- a/edb/edgeql-parser/src/keywords.rs
+++ b/edb/edgeql-parser/src/keywords.rs
@@ -20,6 +20,7 @@ pub const UNRESERVED_KEYWORDS: &[&str] = &[
     "cube",
     "current",
     "database",
+    "namespace",
     "ddl",
     "declare",
     "default",
diff --git a/edb/edgeql/ast.py b/edb/edgeql/ast.py
index dff05cf2491..7d79b4ef2b0 100644
--- a/edb/edgeql/ast.py
+++ b/edb/edgeql/ast.py
@@ -836,6 +836,20 @@ class DropDatabase(DropObject, DatabaseCommand):
     pass
 
 
+class NameSpaceCommand(ExternalObjectCommand):
+    __abstract_node__ = True
+    object_class: qltypes.SchemaObjectClass = (
+        qltypes.SchemaObjectClass.NAMESPACE)
+
+
+class CreateNameSpace(CreateObject, NameSpaceCommand):
+    pass
+
+
+class DropNameSpace(DropObject, NameSpaceCommand):
+    pass
+
+
 class ExtensionPackageCommand(GlobalObjectCommand):
     __abstract_node__ = True
     object_class: qltypes.SchemaObjectClass = (
diff --git a/edb/edgeql/codegen.py b/edb/edgeql/codegen.py
index 55ddb58a8fc..6f5b1c4105d 100644
--- a/edb/edgeql/codegen.py
+++ b/edb/edgeql/codegen.py
@@ -1033,6 +1033,12 @@ def visit_AlterDatabase(self, node: qlast.AlterDatabase) -> None:
     def visit_DropDatabase(self, node: qlast.DropDatabase) -> None:
         self._visit_DropObject(node, 'DATABASE')
 
+    def visit_CreateNameSpace(self, node: qlast.CreateNameSpace) -> None:
+        self._visit_CreateObject(node, 'NAMESPACE')
+
+    def visit_DropNameSpace(self, node: qlast.DropNameSpace) -> None:
+        self._visit_DropObject(node, 'NAMESPACE')
+
     def visit_CreateRole(self, node: qlast.CreateRole) -> None:
         after_name = lambda: self._ddl_visit_bases(node)
         keywords = []
diff --git a/edb/edgeql/parser/grammar/ddl.py b/edb/edgeql/parser/grammar/ddl.py
index a616ce3ac8a..8a8c92d9d49 100644
--- a/edb/edgeql/parser/grammar/ddl.py
+++ b/edb/edgeql/parser/grammar/ddl.py
@@ -56,6 +56,9 @@ class DDLStmt(Nonterm):
     def reduce_DatabaseStmt(self, *kids):
         self.val = kids[0].val
 
+    def reduce_NameSpaceStmt(self, *kids):
+        self.val = kids[0].val
+
     def reduce_RoleStmt(self, *kids):
         self.val = kids[0].val
 
@@ -664,6 +667,68 @@ class DropDatabaseStmt(Nonterm):
     def reduce_DROP_DATABASE_DatabaseName(self, *kids):
         self.val = qlast.DropDatabase(name=kids[2].val)
 
+
+#
+# NAMESPACE
+#
+
+
+class NameSpaceName(Nonterm):
+
+    def reduce_Identifier(self, kid):
+        self.val = qlast.ObjectRef(
+            module=None,
+            name=kid.val
+        )
+
+    def reduce_ReservedKeyword(self, *kids):
+        name = kids[0].val
+        if (
+            name[:2] == '__' and name[-2:] == '__'
+        ):
+            raise EdgeQLSyntaxError(
+                "identifiers surrounded by double underscores are forbidden",
+                context=kids[0].context)
+
+        self.val = qlast.ObjectRef(
+            module=None,
+            name=name
+        )
+
+
+class NameSpaceStmt(Nonterm):
+
+    def reduce_CreateNameSpaceStmt(self, *kids):
+        self.val = kids[0].val
+
+    def reduce_DropNameSpaceStmt(self, *kids):
+        self.val = kids[0].val
+
+
+#
+# CREATE NAMESPACE
+#
+
+
+commands_block(
+    'CreateNameSpace',
+    SetFieldStmt,
+)
+
+
+class CreateNameSpaceStmt(Nonterm):
+    def reduce_CREATE_NAMESPACE_NameSpaceName(self, *kids):
+        """%reduce CREATE NAMESPACE NameSpaceName
+        """
+        self.val = qlast.CreateNameSpace(name=kids[2].val)
+
+
+#
+# DROP NAMESPACE
+#
+class DropNameSpaceStmt(Nonterm):
+    def reduce_DROP_NAMESPACE_DatabaseName(self, *kids):
+        self.val = qlast.DropNameSpace(name=kids[2].val)
+
 
 #
 # EXTENSION PACKAGE
diff --git a/edb/edgeql/qltypes.py b/edb/edgeql/qltypes.py
index 090c8e8a651..42048e5041a 100644
--- a/edb/edgeql/qltypes.py
+++ b/edb/edgeql/qltypes.py
@@ -237,6 +237,7 @@ class SchemaObjectClass(s_enum.StrEnum):
     SCALAR_TYPE = 'SCALAR TYPE'
     TUPLE_TYPE = 'TUPLE TYPE'
     TYPE = 'TYPE'
+    NAMESPACE = 'NAMESPACE'
 
 
 class LinkTargetDeleteAction(s_enum.StrEnum):
diff --git a/edb/lib/sys.edgeql b/edb/lib/sys.edgeql
index 98738ca4d59..fc371c10562 100644
--- a/edb/lib/sys.edgeql
+++ b/edb/lib/sys.edgeql
@@ -37,6 +37,8 @@ CREATE TYPE sys::Database EXTENDING sys::SystemObject {
     };
 };
 
+CREATE TYPE sys::NameSpace EXTENDING sys::SystemObject;
+
 CREATE TYPE sys::ExtensionPackage EXTENDING sys::SystemObject {
 
     CREATE REQUIRED PROPERTY script -> str;
diff --git a/edb/pgsql/dbops/__init__.py b/edb/pgsql/dbops/__init__.py
index dc8fdb4a0b4..6e9b1212b7b 100644
--- a/edb/pgsql/dbops/__init__.py
+++ b/edb/pgsql/dbops/__init__.py
@@ -40,3 +40,4 @@
 from .triggers import *  # NOQA
 from .types import *  # NOQA
 from .views import *  # NOQA
+from .namespace import *  # NOQA
diff --git a/edb/pgsql/dbops/namespace.py b/edb/pgsql/dbops/namespace.py
new file mode 100644
index 00000000000..86a6df68d06
--- /dev/null
+++ b/edb/pgsql/dbops/namespace.py
@@ -0,0 +1,56 @@
+#
+# This source file is part of the EdgeDB open source project.
+#
+# Copyright 2008-present MagicStack Inc. and the EdgeDB authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+
+from __future__ import annotations
+
+from . import base
+from . import ddl
+from ..common import quote_ident as qi
+
+
+class NameSpace:
+    def __init__(
+        self,
+        name: str,
+    ):
+        self.name = name
+
+    def get_type(self):
+        return 'NAMESPACE'
+
+    def get_id(self):
+        return qi(self.name)
+
+    def is_shared(self) -> bool:
+        return False
+
+
+class DropNameSpace(
+    ddl.SchemaObjectOperation,
+    ddl.NonTransactionalDDLOperation
+):
+
+    def code(self, block: base.PLBlock) -> str:
+        schemas = ",".join(
+            [
+                qi(self.name.name + "_" + schema)
+                for schema in ['edgedb', 'edgedbss', 'edgedbpub', 'edgedbstd', 'edgedbinstdata', 'edgedbext']
+            ]
+        )
+        return f'DROP SCHEMA IF EXISTS {schemas} CASCADE;'
diff --git a/edb/pgsql/delta.py b/edb/pgsql/delta.py
index 0f612602311..0bed205d171 100644
--- a/edb/pgsql/delta.py
+++ b/edb/pgsql/delta.py
@@ -52,6 +52,7 @@
 from edb.schema import properties as s_props
 from edb.schema import migrations as s_migrations
 from edb.schema import modules as s_mod
+from edb.schema import namespace as s_ns
 from edb.schema import name as sn
 from edb.schema import objects as so
 from edb.schema import operators as s_opers
@@ -6838,6 +6839,114 @@ class DeleteModule(ModuleMetaCommand, adapts=s_mod.DeleteModule):
     pass
 
 
+class NameSpaceMetaCommand(MetaCommand):
+    pass
+
+
+class CreateNameSpace(NameSpaceMetaCommand, adapts=s_ns.CreateNameSpace):
+    def apply(
+        self,
+        schema: s_schema.Schema,
+        context: sd.CommandContext,
+    ) -> s_schema.Schema:
+        schema = super().apply(schema, context)
+
+        ns_id = str(self.scls.id)
+        name__internal = str(self.scls.get_name(schema))
+        name = self.scls.get_displayname(schema)
+
+        metadata = {
+            ns_id: {
+                'id': ns_id,
+                'name': name,
+                'name__internal': name__internal,
+                'builtin': self.scls.get_builtin(schema),
+                'internal': self.scls.get_internal(schema),
+            }
+        }
+
+        ctx_backend_params = context.backend_runtime_params
+        if ctx_backend_params is not None:
+            backend_params = cast(
+                params.BackendRuntimeParams, ctx_backend_params
+            )
+        else:
+            backend_params = params.get_default_runtime_params()
+
+        if backend_params.has_create_database:
+            tenant_id = self._get_tenant_id(context)
+            tpl_db_name = common.get_database_backend_name(
+                edbdef.EDGEDB_TEMPLATE_DB, tenant_id=tenant_id
+            )
+
+            self.pgops.add(
+                dbops.UpdateMetadataSection(
+                    dbops.Database(name=tpl_db_name),
+                    section='NameSpace',
+                    metadata=metadata
+                )
+            )
+        else:
+            self.pgops.add(
+                dbops.UpdateSingleDBMetadataSection(
+                    edbdef.EDGEDB_TEMPLATE_DB,
+                    section='NameSpace',
+                    metadata=metadata
+                )
+            )
+
+        return schema
+
+
+class DeleteNameSpace(NameSpaceMetaCommand, adapts=s_ns.DeleteNameSpace):
+    def apply(
+        self,
+        schema: s_schema.Schema,
+        context: sd.CommandContext,
+    ) -> s_schema.Schema:
+        schema = super().apply(schema, context)
+        ns_id = str(self.scls.id)
+
+        ctx_backend_params = context.backend_runtime_params
+        if ctx_backend_params is not None:
+            backend_params = cast(
+                params.BackendRuntimeParams, ctx_backend_params
+            )
+        else:
+            backend_params = params.get_default_runtime_params()
+
+        tpl_db_name = common.get_database_backend_name(
+            edbdef.EDGEDB_TEMPLATE_DB, tenant_id=backend_params.tenant_id
+        )
+
+        metadata = {
+            ns_id: None
+        }
+        if backend_params.has_create_database:
+            self.pgops.add(
+                dbops.UpdateMetadataSection(
+                    dbops.Database(name=tpl_db_name),
+                    section='NameSpace',
+                    metadata=metadata
+                )
+            )
+        else:
+            self.pgops.add(
+                dbops.UpdateSingleDBMetadataSection(
+                    edbdef.EDGEDB_TEMPLATE_DB,
+                    section='NameSpace',
+                    metadata=metadata
+                )
+            )
+
+        self.pgops.add(
+            dbops.DropNameSpace(
+                dbops.NameSpace(str(self.classname))
+            )
+        )
+        return schema
+
+
 class DatabaseMixin:
     def ensure_has_create_database(self, backend_params):
         if not backend_params.has_create_database:
diff --git a/edb/pgsql/metaschema.py b/edb/pgsql/metaschema.py
index 0928373c9f0..eb5064fbc64 100644
--- a/edb/pgsql/metaschema.py
+++ b/edb/pgsql/metaschema.py
@@ -4803,6 +4803,107 @@ def _generate_schema_ver_views(schema: s_schema.Schema) -> List[dbops.View]:
     return views
 
 
+def _generate_namespace_views(schema: s_schema.Schema) -> List[dbops.View]:
+    NameSpace = schema.get('sys::NameSpace', type=s_objtypes.ObjectType)
+    annos = NameSpace.getptr(
+        schema, s_name.UnqualName('annotations'), type=s_links.Link)
+    int_annos = NameSpace.getptr(
+        schema, s_name.UnqualName('annotations__internal'), type=s_links.Link)
+
+    view_query = f'''
+        SELECT
+            ((d.description)->>'id')::uuid
+                AS {qi(ptr_col_name(schema, NameSpace, 'id'))},
+            (SELECT id FROM edgedb."_SchemaObjectType"
+                 WHERE name = 'sys::NameSpace')
+                AS {qi(ptr_col_name(schema, NameSpace, '__type__'))},
+            False AS {qi(ptr_col_name(schema, NameSpace, 'internal'))},
+            (d.description)->>'name'
+                AS {qi(ptr_col_name(schema, NameSpace, 'name'))},
+            (d.description)->>'name'
+                AS {qi(ptr_col_name(schema, NameSpace, 'name__internal'))},
+            ARRAY[]::text[]
+                AS {qi(ptr_col_name(schema, NameSpace, 'computed_fields'))},
+            ((d.description)->>'builtin')::bool
+                AS {qi(ptr_col_name(schema, NameSpace, 'builtin'))},
+            (d.description)->>'module_name'
+                AS {qi(ptr_col_name(schema, NameSpace, 'module_name'))},
+            ((d.description)->>'external')::bool
+                AS {qi(ptr_col_name(schema, NameSpace, 'external'))}
+        FROM
+            pg_database dat
+            CROSS JOIN LATERAL (
+                SELECT
+                    edgedb.shobj_metadata(dat.oid, 'pg_database')
+                        AS description
+            ) AS d
+        WHERE
+            (d.description)->>'id' IS NOT NULL
+            AND (d.description)->>'tenant_id' = edgedb.get_backend_tenant_id()
+    '''
+
+    annos_link_query = f'''
+        SELECT
+            ((d.description)->>'id')::uuid
+                AS {qi(ptr_col_name(schema, annos, 'source'))},
+            (annotations->>'id')::uuid
+                AS {qi(ptr_col_name(schema, annos, 'target'))},
+            (annotations->>'value')::text
+                AS {qi(ptr_col_name(schema, annos, 'value'))},
+            (annotations->>'owned')::bool
+                AS {qi(ptr_col_name(schema, annos, 'owned'))}
+        FROM
+            pg_database dat
+            CROSS JOIN LATERAL (
+                SELECT
+                    edgedb.shobj_metadata(dat.oid, 'pg_database')
+                        AS description
+            ) AS d
+            CROSS JOIN LATERAL
+                ROWS FROM (
+                    jsonb_array_elements((d.description)->'annotations')
+                ) AS annotations
+    '''
+
+    int_annos_link_query = f'''
+        SELECT
+            ((d.description)->>'id')::uuid
+                AS {qi(ptr_col_name(schema, int_annos, 'source'))},
+            (annotations->>'id')::uuid
+                AS {qi(ptr_col_name(schema, int_annos, 'target'))},
+            (annotations->>'owned')::bool
+                AS {qi(ptr_col_name(schema, int_annos, 'owned'))}
+        FROM
+            pg_database dat
+            CROSS JOIN LATERAL (
+                SELECT
+                    edgedb.shobj_metadata(dat.oid, 'pg_database')
+                        AS description
+            ) AS d
+            CROSS JOIN LATERAL
+                ROWS FROM (
+                    jsonb_array_elements(
+                        (d.description)->'annotations__internal'
+                    )
+                ) AS annotations
+    '''
+
+    objects = {
+        Database: view_query,
+        annos: annos_link_query,
+        int_annos: int_annos_link_query,
+    }
+
+    views = []
+    for obj, query in objects.items():
+        tabview = dbops.View(name=tabname(schema, obj), query=query)
+        inhview = dbops.View(name=inhviewname(schema, obj), query=query)
+        views.append(tabview)
+        views.append(inhview)
+
+    return views
+
+
 def _make_json_caster(
     schema: s_schema.Schema,
     stype: s_types.Type,
diff --git a/edb/schema/namespace.py b/edb/schema/namespace.py
new file mode 100644
index 00000000000..734e5bc27bc
--- /dev/null
+++ b/edb/schema/namespace.py
@@ -0,0 +1,78 @@
+#
+# This source file is part of the EdgeDB open source project.
+#
+# Copyright 2008-present MagicStack Inc. and the EdgeDB authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+
+from __future__ import annotations
+
+from edb import errors
+from edb.edgeql import ast as qlast
+from edb.edgeql import qltypes
+from . import abc as s_abc
+from . import annos as s_anno
+from . import delta as sd
+from . import objects as so
+from . import schema as s_schema
+
+
+class NameSpace(
+    so.ExternalObject,
+    s_anno.AnnotationSubject,
+    s_abc.NameSpace,
+    qlkind=qltypes.SchemaObjectClass.NAMESPACE,
+    data_safe=False,
+):
+    pass
+
+
+class NameSpaceCommandContext(sd.ObjectCommandContext[NameSpace]):
+    pass
+
+
+class NameSpaceCommand(
+    sd.ExternalObjectCommand[NameSpace],
+    context_class=NameSpaceCommandContext,
+):
+    def _validate_name(
+        self,
+        schema: s_schema.Schema,
+        context: sd.CommandContext,
+    ) -> None:
+        name = self.get_attribute_value('name')
+        if str(name).startswith('pg_'):
+            source_context = self.get_attribute_source_context('name')
+            raise errors.SchemaDefinitionError(
+                f'NameSpace names cannot start with \'pg_\', '
+                f'as such names are reserved for system schemas',
+                context=source_context,
+            )
+
+
+class CreateNameSpace(NameSpaceCommand, sd.CreateExternalObject[NameSpace]):
+    astnode = qlast.CreateNameSpace
+
+    def validate_create(
+        self,
+        schema: s_schema.Schema,
+        context: sd.CommandContext,
+    ) -> None:
+        super().validate_create(schema, context)
+        self._validate_name(schema, context)
+
+
+class DeleteNameSpace(NameSpaceCommand, sd.DeleteExternalObject[NameSpace]):
+    astnode = qlast.DropNameSpace
diff --git a/edb/server/bootstrap.py b/edb/server/bootstrap.py
index 291cf36ac46..e6c4e161cbd 100644
--- a/edb/server/bootstrap.py
+++ b/edb/server/bootstrap.py
@@ -998,6 +998,73 @@ async def _init_stdlib(
     return stdlib, config_spec, compiler
 
 
+async def store_tpl_sql(tpldbdump: bytes, conn: pgcon.PGConnection):
+    text = f"""\
+        INSERT INTO edgedbinstdata.instdata (key, text)
+        VALUES(
+            {pg_common.quote_literal('tpl_sql')},
+            {pg_common.quote_literal(tpldbdump.decode('utf-8'))}::text
+        )
+    """
+
+    await _execute(conn, text)
+
+
+async def gen_tpl_dump(cluster: pgcluster.BaseCluster):
+    tpl_db_name = edbdef.EDGEDB_TEMPLATE_DB
+    tpl_pg_db_name = cluster.get_db_name(tpl_db_name)
+    tpldbdump = await cluster.dump_database(
+        tpl_pg_db_name,
+        exclude_schemas=['edgedbext'],
+        dump_object_owners=False,
+    )
+    commands = [dbops.CreateSchema(name='{ns_edgedbext}')]
+    for uuid_func in [
+        'uuid_generate_v1',
+        'uuid_generate_v1mc',
+        'uuid_generate_v4',
+        'uuid_nil',
+        'uuid_ns_dns',
+        'uuid_ns_oid',
+        'uuid_ns_url',
+        'uuid_ns_x500',
+    ]:
+        commands.append(
+            dbops.CreateOrReplaceFunction(
+                dbops.Function(
+                    name=('{ns_edgedbext}', uuid_func),
+                    returns=('pg_catalog', 'uuid'), language='plpgsql',
+                    text=f"""
+                    BEGIN
+                        RETURN edgedbext.{uuid_func}();
+                    END;
+                    """
+                )
+            )
+        )
+
+    for uuid_func in ['uuid_generate_v3', 'uuid_generate_v5']:
+        commands.append(
+            dbops.CreateOrReplaceFunction(
+                dbops.Function(
+                    name=('{ns_edgedbext}', uuid_func),
+                    returns=('pg_catalog', 'uuid'), language='plpgsql',
+                    args=[('namespace', 'uuid'), ('name', 'text')],
+                    text=f"""
+                    BEGIN
+                        RETURN edgedbext.{uuid_func}(namespace, name);
+                    END;
+                    """
+                )
+            )
+        )
+    command_group = dbops.CommandGroup()
+    command_group.add_commands(commands)
+    block = dbops.PLTopBlock()
+    command_group.generate(block)
+    return block.to_string().encode('utf-8') + tpldbdump
+
+
 async def _init_defaults(schema, compiler, conn):
     script = '''
         CREATE MODULE default;
@@ -1359,6 +1426,17 @@ async def _get_instance_data(conn: pgcon.PGConnection) -> Dict[str, Any]:
     return json.loads(data)
 
 
+async def get_tpl_sql(conn: pgcon.PGConnection) -> bytes:
+    data = await conn.sql_fetch_val(
+        b"""
+        SELECT text
+        FROM edgedbinstdata.instdata
+        WHERE key = 'tpl_sql'
+        """,
+    )
+    return data
+
+
 async def _check_catalog_compatibility(
     ctx: BootstrapContext,
 ) -> pgcon.PGConnection:
@@ -1504,6 +1582,8 @@ async def _start(ctx: BootstrapContext) -> None:
 
         # Initialize global config
         config.set_settings(config_spec)
+        if ctx.cluster._pg_bin_dir is None:
+            await ctx.cluster.lookup_postgres()
     finally:
         conn.terminate()
 
diff --git a/edb/server/compiler/compiler.py b/edb/server/compiler/compiler.py
index 0aab3ba1bc6..c558900d267 100644
--- a/edb/server/compiler/compiler.py
+++ b/edb/server/compiler/compiler.py
@@ -1119,11 +1119,14 @@ def _compile_and_apply_ddl_stmt(
     create_db = None
     drop_db = None
     create_db_template = None
+    create_ns = None
     if isinstance(stmt, qlast.DropDatabase):
         drop_db = stmt.name.name
     elif isinstance(stmt, qlast.CreateDatabase):
         create_db = stmt.name.name
         create_db_template = stmt.template.name if stmt.template else None
+    elif isinstance(stmt, qlast.CreateNameSpace):
+        create_ns = stmt.name.name
 
     if debug.flags.delta_execute:
         debug.header('Delta Script')
@@ -1140,6 +1143,7 @@ def _compile_and_apply_ddl_stmt(
         ),
         create_db=create_db,
         drop_db=drop_db,
+        create_ns=create_ns,
         create_db_template=create_db_template,
         has_role_ddl=isinstance(stmt, qlast.RoleCommand),
         ddl_stmt_id=ddl_stmt_id,
@@ -2050,6 +2054,7 @@ def _try_compile(
             unit.create_db = comp.create_db
             unit.drop_db = comp.drop_db
             unit.create_db_template = comp.create_db_template
+            unit.create_ns = comp.create_ns
             unit.has_role_ddl = comp.has_role_ddl
             unit.ddl_stmt_id = comp.ddl_stmt_id
             if comp.user_schema is not None:
diff --git a/edb/server/compiler/dbstate.py b/edb/server/compiler/dbstate.py
index 25667df076b..e3a4417b010 100644
--- a/edb/server/compiler/dbstate.py
+++ b/edb/server/compiler/dbstate.py
@@ -135,6 +135,7 @@ class DDLQuery(BaseQuery):
     is_transactional: bool = True
     single_unit: bool = False
     create_db: Optional[str] = None
+    create_ns: Optional[str] = None
     drop_db: Optional[str] = None
     create_db_template: Optional[str] = None
     has_role_ddl: bool = False
@@ -260,6 +261,10 @@ class QueryUnit:
     # close all inactive unused pooled connections to the template db.
     create_db_template: Optional[str] = None
 
+    # If non-None, contains the name of the namespace that is about to be
+    # created.
+    create_ns: Optional[str] = None
+
     # If non-None, the DDL statement will emit data packets marked
     # with the indicated ID.
     ddl_stmt_id: Optional[str] = None
diff --git a/edb/server/pgcluster.py b/edb/server/pgcluster.py
index 306ac1146d6..6c65180f085 100644
--- a/edb/server/pgcluster.py
+++ b/edb/server/pgcluster.py
@@ -1174,7 +1174,7 @@ async def _start_logged_subprocess(
             asyncio.subprocess.PIPE
             if log_stderr or capture_stderr
             else asyncio.subprocess.DEVNULL
         ),
-        limit=2 ** 20,  # 1 MiB
+        limit=2 ** 25,  # 32 MiB
         **kwargs,
     )
 
diff --git a/edb/server/protocol/execute.pyx b/edb/server/protocol/execute.pyx
index b44e29084be..f4b1654caf2 100644
--- a/edb/server/protocol/execute.pyx
+++ b/edb/server/protocol/execute.pyx
@@ -84,6 +84,8 @@ async def execute(
                 await server._on_before_drop_db(query_unit.drop_db, dbv.dbname)
             if query_unit.system_config:
                 await execute_system_config(be_conn, dbv, query_unit)
+            if query_unit.create_ns:
+                await server.create_namespace(be_conn, query_unit.create_ns)
         else:
             config_ops = query_unit.config_ops
 
diff --git a/edb/server/server.py b/edb/server/server.py
index 8e539b96708..363e6d38e79 100644
--- a/edb/server/server.py
+++ b/edb/server/server.py
@@ -22,6 +22,7 @@
 import contextlib
 import functools
 import hashlib
+import re
 from typing import *
 
 import asyncio
@@ -66,7 +67,7 @@
 from edb.server import metrics
 from edb.server import pgcon
 from edb.server.pgcon import errors as pgcon_errors
-
+from edb.server.bootstrap import get_tpl_sql, gen_tpl_dump, store_tpl_sql
 from . import dbview
 
 if TYPE_CHECKING:
@@ -273,6 +274,7 @@ def __init__(
         self._admin_ui = admin_ui
 
         self._db_to_bigint = {}
+        self._ns_tpl_sql = None
 
     @contextlib.asynccontextmanager
     async def aquire_distributed_lock(self, dbname, conn):
@@ -970,6 +972,13 @@ async def _load_instance_data(self):
                 WHERE key = 'report_configs_typedesc';
             ''')
 
+            if (tpldbdump := await get_tpl_sql(syscon)) is None:
+                tpldbdump = await gen_tpl_dump(self._cluster)
+                await store_tpl_sql(tpldbdump, syscon)
+                self._ns_tpl_sql = tpldbdump
+            else:
+                self._ns_tpl_sql = tpldbdump
+
         finally:
             self._release_sys_pgcon()
 
@@ -1919,6 +1928,21 @@ def on_switch_over(self):
             call_on_switch_over=False
         )
 
+    async def create_namespace(self, be_conn: pgcon.PGConnection, name: str):
+        tpl_sql = re.sub(
+            rb'(edgedb)(\.|instdata|pub|ss|std|;)',
+            name.encode('utf-8') + rb'_\1\2',
+            self._ns_tpl_sql,
+            flags=re.MULTILINE,
+        )
+        tpl_sql = re.sub(
+            rb'({ns_edgedbext})',
+            name.encode('utf-8') + rb'_edgedbext',
+            tpl_sql,
+            flags=re.MULTILINE,
+        )
+        await be_conn.sql_execute(tpl_sql)
+
     def get_active_pgcon_num(self) -> int:
         return (
             self._pg_pool.current_capacity - self._pg_pool.get_pending_conns()

From 272f64eafdb056d17599e7dac757f2b0e0e05878 Mon Sep 17 00:00:00 2001
From: JMPRIEST
Date: Fri, 28 Apr 2023 15:14:49 +0800
Subject: [PATCH 02/20] :construction: work in progress on the edb schema
 update logic
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 edb/lib/sys.edgeql           |   4 +-
 edb/pgsql/dbops/namespace.py |  14 +++-
 edb/pgsql/delta.py           |  94 +++------------------
 edb/pgsql/metaschema.py      | 148 +++++++++++++++++------------------
 edb/schema/namespace.py      |  15 +++-
 5 files changed, 113 insertions(+), 162 deletions(-)

diff --git a/edb/lib/sys.edgeql b/edb/lib/sys.edgeql
index fc371c10562..9f9f0a152c4 100644
--- a/edb/lib/sys.edgeql
+++ b/edb/lib/sys.edgeql
@@ -37,7 +37,9 @@ CREATE TYPE sys::Database EXTENDING sys::SystemObject {
     };
 };
 
-CREATE TYPE sys::NameSpace EXTENDING sys::SystemObject;
+CREATE TYPE sys::NameSpace EXTENDING sys::SystemObject {
+    CREATE REQUIRED SINGLE LINK db -> sys::Database;
+};
 
 CREATE TYPE sys::ExtensionPackage EXTENDING sys::SystemObject {
 
diff --git a/edb/pgsql/dbops/namespace.py b/edb/pgsql/dbops/namespace.py
index 86a6df68d06..fcc8706046e 100644
--- a/edb/pgsql/dbops/namespace.py
+++ b/edb/pgsql/dbops/namespace.py
@@ -19,16 +19,20 @@
 
 from __future__ import annotations
 
+from typing import Optional, Mapping, Any
+
 from . import base
 from . import ddl
 from ..common import quote_ident as qi
 
 
-class NameSpace:
+class NameSpace(base.DBObject):
     def __init__(
         self,
         name: str,
+        metadata: Optional[Mapping[str, Any]] = None,
     ):
+        super().__init__(metadata=metadata)
         self.name = name
 
     def get_type(self):
@@ -41,6 +45,14 @@ def is_shared(self) -> bool:
         return False
 
 
+class CreateNameSpace(ddl.CreateObject, ddl.NonTransactionalDDLOperation):
+    def __init__(self, object, **kwargs):
+        super().__init__(object, **kwargs)
+
+    def code(self, block: base.PLBlock) -> str:
+        return ''
+
+
 class DropNameSpace(
     ddl.SchemaObjectOperation,
     ddl.NonTransactionalDDLOperation
 
diff --git a/edb/pgsql/delta.py b/edb/pgsql/delta.py
index 0bed205d171..4da79b55780 100644
--- a/edb/pgsql/delta.py
+++ b/edb/pgsql/delta.py
@@ -6850,51 +6850,19 @@ def apply(
         context: sd.CommandContext,
     ) -> s_schema.Schema:
         schema = super().apply(schema, context)
-
-        ns_id = str(self.scls.id)
-        name__internal = str(self.scls.get_name(schema))
-        name = self.scls.get_displayname(schema)
-
-        metadata = {
-            ns_id: {
-                'id': ns_id,
-                'name': name,
-                'name__internal': name__internal,
-                'builtin': self.scls.get_builtin(schema),
-                'internal': self.scls.get_internal(schema),
-            }
-        }
-
-        ctx_backend_params = context.backend_runtime_params
-        if ctx_backend_params is not None:
-            backend_params = cast(
-                params.BackendRuntimeParams, ctx_backend_params
-            )
-        else:
-            backend_params = params.get_default_runtime_params()
-
-        if backend_params.has_create_database:
-            tenant_id = self._get_tenant_id(context)
-            tpl_db_name = common.get_database_backend_name(
-                edbdef.EDGEDB_TEMPLATE_DB, tenant_id=tenant_id
-            )
-
-            self.pgops.add(
-                dbops.UpdateMetadataSection(
-                    dbops.Database(name=tpl_db_name),
-                    section='NameSpace',
-                    metadata=metadata
-                )
-            )
-        else:
-            self.pgops.add(
-                dbops.UpdateSingleDBMetadataSection(
-                    edbdef.EDGEDB_TEMPLATE_DB,
-                    section='NameSpace',
-                    metadata=metadata
-                )
+        self.pgops.add(
+            dbops.CreateNameSpace(
+                dbops.NameSpace(
+                    str(self.classname),
+                    metadata=dict(
+                        id=str(self.scls.id),
+                        builtin=self.get_attribute_value('builtin'),
+                        name=str(self.classname),
+                        db=self.get_attribute_value('db')
+                    ),
+                ),
             )
-
+        )
         return schema
 
 
@@ -6905,45 +6873,7 @@ def apply(
         context: sd.CommandContext,
     ) -> s_schema.Schema:
         schema = super().apply(schema, context)
-        ns_id = str(self.scls.id)
 
-        ctx_backend_params = context.backend_runtime_params
-        if ctx_backend_params is not None:
-            backend_params = cast(
-                params.BackendRuntimeParams, ctx_backend_params
-            )
-        else:
-            backend_params = params.get_default_runtime_params()
-
-        tpl_db_name = common.get_database_backend_name(
-            edbdef.EDGEDB_TEMPLATE_DB, tenant_id=backend_params.tenant_id
-        )
-
-        metadata = {
-            ns_id: None
-        }
-        if backend_params.has_create_database:
-            self.pgops.add(
-                dbops.UpdateMetadataSection(
-                    dbops.Database(name=tpl_db_name),
-                    section='NameSpace',
-                    metadata=metadata
-                )
-            )
-        else:
-            self.pgops.add(
-                dbops.UpdateSingleDBMetadataSection(
-                    edbdef.EDGEDB_TEMPLATE_DB,
-                    section='NameSpace',
-                    metadata=metadata
-                )
-            )
-
-        self.pgops.add(
-            dbops.DropNameSpace(
-                dbops.NameSpace(str(self.classname))
-            )
-        )
         return schema
 
 
diff --git a/edb/pgsql/metaschema.py b/edb/pgsql/metaschema.py
index eb5064fbc64..2f75ac08812 100644
--- a/edb/pgsql/metaschema.py
+++ b/edb/pgsql/metaschema.py
@@ -4806,90 +4806,85 @@ def _generate_schema_ver_views(schema: s_schema.Schema) -> List[dbops.View]:
 def _generate_namespace_views(schema: s_schema.Schema) -> List[dbops.View]:
     NameSpace = schema.get('sys::NameSpace', type=s_objtypes.ObjectType)
     annos = NameSpace.getptr(
-        schema, s_name.UnqualName('annotations'), type=s_links.Link)
+        schema, s_name.UnqualName('annotations'), type=s_links.Link
+    )
     int_annos = NameSpace.getptr(
-        schema, s_name.UnqualName('annotations__internal'), type=s_links.Link)
+        schema, s_name.UnqualName('annotations__internal'), type=s_links.Link
+    )
 
     view_query = f'''
-        SELECT
-            ((d.description)->>'id')::uuid
-                AS {qi(ptr_col_name(schema, NameSpace, 'id'))},
-            (SELECT id FROM edgedb."_SchemaObjectType"
-                 WHERE name = 'sys::NameSpace')
-                AS {qi(ptr_col_name(schema, NameSpace, '__type__'))},
-            False AS {qi(ptr_col_name(schema, NameSpace, 'internal'))},
-            (d.description)->>'name'
-                AS {qi(ptr_col_name(schema, NameSpace, 'name'))},
-            (d.description)->>'name'
-                AS {qi(ptr_col_name(schema, NameSpace, 'name__internal'))},
-            ARRAY[]::text[]
-                AS {qi(ptr_col_name(schema, NameSpace, 'computed_fields'))},
-            ((d.description)->>'builtin')::bool
-                AS {qi(ptr_col_name(schema, NameSpace, 'builtin'))},
-            (d.description)->>'module_name'
-                AS {qi(ptr_col_name(schema, NameSpace, 'module_name'))},
-            ((d.description)->>'external')::bool
-                AS {qi(ptr_col_name(schema, NameSpace, 'external'))}
-        FROM
-            pg_database dat
-            CROSS JOIN LATERAL (
-                SELECT
-                    edgedb.shobj_metadata(dat.oid, 'pg_database')
-                        AS description
-            ) AS d
-        WHERE
-            (d.description)->>'id' IS NOT NULL
-            AND (d.description)->>'tenant_id' = edgedb.get_backend_tenant_id()
-    '''
+        SELECT
+            (ns.value->>'id')::uuid
+                AS {qi(ptr_col_name(schema, NameSpace, 'id'))},
+            (SELECT id FROM edgedb."_SchemaObjectType"
+                 WHERE name = 'sys::NameSpace')
+                AS {qi(ptr_col_name(schema, NameSpace, '__type__'))},
+            (ns.value->>'name')
+                AS {qi(ptr_col_name(schema, NameSpace, 'name'))},
+            (ns.value->>'name__internal')
+                AS {qi(ptr_col_name(schema, NameSpace, 'name__internal'))},
+            ARRAY[]::text[]
+                AS {qi(ptr_col_name(schema, NameSpace, 'computed_fields'))},
+            (ns.value->>'builtin')::bool
+                AS {qi(ptr_col_name(schema, NameSpace, 'builtin'))},
+            (ns.value->>'internal')::bool
+                AS {qi(ptr_col_name(schema, NameSpace, 'internal'))},
+            (ns.value->>'module_name')
+                AS {qi(ptr_col_name(schema, NameSpace, 'module_name'))},
+            ((ns.value )->>'external')::bool
+                AS {qi(ptr_col_name(schema, NameSpace, 'external'))}
+        FROM
+            jsonb_each(
+                edgedb.get_database_metadata(
+                    {ql(defines.EDGEDB_SYSTEM_DB)}
+                ) -> 'NameSpace'
+            ) AS ns
+    '''
 
     annos_link_query = f'''
-        SELECT
-            ((d.description)->>'id')::uuid
-                AS {qi(ptr_col_name(schema, annos, 'source'))},
-            (annotations->>'id')::uuid
-                AS {qi(ptr_col_name(schema, annos, 'target'))},
-            (annotations->>'value')::text
-                AS {qi(ptr_col_name(schema, annos, 'value'))},
-            (annotations->>'owned')::bool
-                AS {qi(ptr_col_name(schema, annos, 'owned'))}
-        FROM
-            pg_database dat
-            CROSS JOIN LATERAL (
-                SELECT
-                    edgedb.shobj_metadata(dat.oid, 'pg_database')
-                        AS description
-            ) AS d
-            CROSS JOIN LATERAL
-                ROWS FROM (
-                    jsonb_array_elements((d.description)->'annotations')
-                ) AS annotations
-    '''
+        SELECT
+            (ns.value->>'id')::uuid
+                AS {qi(ptr_col_name(schema, annos, 'source'))},
+            (annotations->>'id')::uuid
+                AS {qi(ptr_col_name(schema, annos, 'target'))},
+            (annotations->>'value')::text
+                AS {qi(ptr_col_name(schema, annos, 'value'))},
+            (annotations->>'is_owned')::bool
+                AS {qi(ptr_col_name(schema, annos, 'owned'))}
+        FROM
+            jsonb_each(
+                edgedb.get_database_metadata(
+                    {ql(defines.EDGEDB_SYSTEM_DB)}
+                ) -> 'NameSpace'
+            ) AS ns
+            CROSS JOIN LATERAL
+                ROWS FROM (
+                    jsonb_array_elements(ns.value->'annotations')
+                ) AS annotations
+    '''
 
     int_annos_link_query = f'''
-        SELECT
-            ((d.description)->>'id')::uuid
-                AS {qi(ptr_col_name(schema, int_annos, 'source'))},
-            (annotations->>'id')::uuid
-                AS {qi(ptr_col_name(schema, int_annos, 'target'))},
-            (annotations->>'owned')::bool
-                AS {qi(ptr_col_name(schema, int_annos, 'owned'))}
-        FROM
-            pg_database dat
-            CROSS JOIN LATERAL (
-                SELECT
-                    edgedb.shobj_metadata(dat.oid, 'pg_database')
-                        AS description
-            ) AS d
-            CROSS JOIN LATERAL
-                ROWS FROM (
-                    jsonb_array_elements(
-                        (d.description)->'annotations__internal'
-                    )
-                ) AS annotations
-    '''
+        SELECT
+            (ns.value->>'id')::uuid
+                AS {qi(ptr_col_name(schema, int_annos, 'source'))},
+            (annotations->>'id')::uuid
+                AS {qi(ptr_col_name(schema, int_annos, 'target'))},
+            (annotations->>'is_owned')::bool
+                AS {qi(ptr_col_name(schema, int_annos, 'owned'))}
+        FROM
+            jsonb_each(
+                edgedb.get_database_metadata(
+                    {ql(defines.EDGEDB_SYSTEM_DB)}
+                ) -> 'NameSpace'
+            ) AS ns
+            CROSS JOIN LATERAL
+                ROWS FROM (
+                    jsonb_array_elements(ns.value->'annotations__internal')
+                ) AS annotations
+    '''
 
     objects = {
-        Database: view_query,
+        NameSpace: view_query,
         annos: annos_link_query,
         int_annos: int_annos_link_query,
     }
@@ -5077,6 +5072,9 @@ async def generate_support_views(
     for verview in _generate_schema_ver_views(schema):
         commands.add_command(dbops.CreateView(verview, or_replace=True))
 
+    for nsview in _generate_namespace_views(schema):
+        commands.add_command(dbops.CreateView(nsview, or_replace=True))
+
     sys_alias_views = _generate_schema_alias_views(
         schema, s_name.UnqualName('sys'))
     for alias_view in sys_alias_views:
diff --git a/edb/schema/namespace.py b/edb/schema/namespace.py
index 734e5bc27bc..1b1106c9e0c 100644
--- a/edb/schema/namespace.py
+++ b/edb/schema/namespace.py
@@ -19,24 +19,25 @@
 
 from __future__ import annotations
 
+import uuid
+
 from edb import errors
 from edb.edgeql import ast as qlast
 from edb.edgeql import qltypes
-from . import abc as s_abc
 from . import annos as s_anno
 from . import delta as sd
 from . import objects as so
 from . import schema as s_schema
+from . import database as s_database
 
 
 class NameSpace(
     so.ExternalObject,
     s_anno.AnnotationSubject,
-    s_abc.NameSpace,
     qlkind=qltypes.SchemaObjectClass.NAMESPACE,
     data_safe=False,
 ):
-    pass
+    db = so.SchemaField(s_database.Database)
 
 
 class NameSpaceCommandContext(sd.ObjectCommandContext[NameSpace]):
     pass
@@ -73,6 +74,14 @@ def validate_create(
         super().validate_create(schema, context)
         self._validate_name(schema, context)
 
+    def _create_begin(
+        self,
+        schema: s_schema.Schema,
+        context: sd.CommandContext,
+    ) -> s_schema.Schema:
+        schema = super()._create_begin(schema, context)
+        return schema
+
 
 class DeleteNameSpace(NameSpaceCommand, sd.DeleteExternalObject[NameSpace]):
     astnode = qlast.DropNameSpace

From fe482cd3e7db3fdd5b643e09dccad6748a075c12 Mon Sep 17 00:00:00 2001
From: JMPRIEST
Date: Sun, 30 Apr 2023 20:37:14 +0800
Subject: [PATCH 03/20] :construction: creating/dropping the pg schemas behind
 CREATE/DROP NAMESPACE is done, and the metadata queries that keep namespaces
 isolated within a database are done; the logic that routes each namespace to
 its schemas still needs work
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 edb/errors/__init__.py           |  4 ++++
 edb/lib/sys.edgeql               |  4 +---
 edb/pgsql/dbops/namespace.py     |  6 +++---
 edb/pgsql/delta.py               |  6 ++++--
 edb/pgsql/metaschema.py          | 28 ++++++++++++++--------------
 edb/schema/namespace.py          | 11 +----------
 edb/server/compiler/compiler.py  |  9 ++++++++-
 edb/server/compiler/errormech.py |  1 +
 edb/server/pgcon/errors.py       |  1 +
 edb/server/protocol/binary.pyx   |  3 +++
 edb/server/protocol/execute.pyx  |  4 ++--
 11 files changed, 42 insertions(+), 35 deletions(-)

diff --git a/edb/errors/__init__.py b/edb/errors/__init__.py
index 68847a18cf6..ee646805cce 100644
--- a/edb/errors/__init__.py
+++ b/edb/errors/__init__.py
@@ -309,6 +309,10 @@ class DuplicateCastDefinitionError(DuplicateDefinitionError):
     _code = 0x_04_05_02_0A
 
 
+class DuplicateNameSpaceDefinitionError(DuplicateDefinitionError):
+    _code = 0x_04_05_02_0B
+
+
 class SessionTimeoutError(QueryError):
     _code = 0x_04_06_00_00
 
diff --git a/edb/lib/sys.edgeql b/edb/lib/sys.edgeql
index 9f9f0a152c4..fc371c10562 100644
--- a/edb/lib/sys.edgeql
+++ b/edb/lib/sys.edgeql
@@ -37,9 +37,7 @@ CREATE TYPE sys::Database EXTENDING sys::SystemObject {
     };
 };
 
-CREATE TYPE sys::NameSpace EXTENDING sys::SystemObject {
-    CREATE REQUIRED SINGLE LINK db -> sys::Database;
-};
+CREATE TYPE sys::NameSpace EXTENDING sys::SystemObject;
 
 CREATE TYPE sys::ExtensionPackage EXTENDING sys::SystemObject {
 
diff --git a/edb/pgsql/dbops/namespace.py b/edb/pgsql/dbops/namespace.py
index fcc8706046e..fac1286c7fa 100644
--- a/edb/pgsql/dbops/namespace.py
+++ b/edb/pgsql/dbops/namespace.py
@@ -36,10 +36,10 @@ def __init__(
         self.name = name
 
     def get_type(self):
-        return 'NAMESPACE'
+        return 'SCHEMA'
 
     def get_id(self):
-        return qi(self.name)
+        return qi(f"{self.name}_edgedb")
 
     def is_shared(self) -> bool:
         return False
@@ -61,7 +61,7 @@ class DropNameSpace(
     def code(self, block: base.PLBlock) -> str:
         schemas = ",".join(
             [
-                qi(self.name.name + "_" + schema)
+                qi(f"{self.name}_{schema}")
                 for schema in ['edgedb', 'edgedbss', 'edgedbpub', 'edgedbstd', 'edgedbinstdata', 'edgedbext']
             ]
         )
         return f'DROP SCHEMA IF EXISTS {schemas} CASCADE;'
diff --git a/edb/pgsql/delta.py b/edb/pgsql/delta.py
index 4da79b55780..cc03a5653ef 100644
--- a/edb/pgsql/delta.py
+++ b/edb/pgsql/delta.py
@@ -6858,7 +6858,7 @@ def apply(
                         id=str(self.scls.id),
                         builtin=self.get_attribute_value('builtin'),
                         name=str(self.classname),
-                        db=self.get_attribute_value('db')
+                        internal=False
                     ),
                 ),
             )
@@ -6873,7 +6873,9 @@ def apply(
         context: sd.CommandContext,
     ) -> s_schema.Schema:
         schema = super().apply(schema, context)
-
+        self.pgops.add(
+            dbops.DropNameSpace(self.classname)
+        )
         return schema
 
 
diff --git a/edb/pgsql/metaschema.py b/edb/pgsql/metaschema.py
index 2f75ac08812..72197f3fef0 100644
--- a/edb/pgsql/metaschema.py
+++ b/edb/pgsql/metaschema.py
@@ -4814,31 +4814,31 @@ def _generate_namespace_views(schema: s_schema.Schema) -> List[dbops.View]:
 
     view_query = f'''
         SELECT
-            (ns.value->>'id')::uuid
+            (ns.description->>'id')::uuid
                 AS {qi(ptr_col_name(schema, NameSpace, 'id'))},
             (SELECT id FROM edgedb."_SchemaObjectType"
                  WHERE name = 'sys::NameSpace')
                 AS {qi(ptr_col_name(schema, NameSpace, '__type__'))},
-            (ns.value->>'name')
+            (ns.description->>'name')
                 AS {qi(ptr_col_name(schema, NameSpace, 'name'))},
-            (ns.value->>'name__internal')
+            (ns.description->>'name__internal')
                 AS {qi(ptr_col_name(schema, NameSpace, 'name__internal'))},
             ARRAY[]::text[]
                 AS {qi(ptr_col_name(schema, NameSpace, 'computed_fields'))},
-            (ns.value->>'builtin')::bool
+            (ns.description->>'builtin')::bool
                 AS {qi(ptr_col_name(schema, NameSpace, 'builtin'))},
-            (ns.value->>'internal')::bool
+            (ns.description->>'internal')::bool
                 AS {qi(ptr_col_name(schema, NameSpace, 'internal'))},
-            (ns.value->>'module_name')
+            (ns.description->>'module_name')
                 AS {qi(ptr_col_name(schema, NameSpace, 'module_name'))},
-            ((ns.value )->>'external')::bool
+            ((ns.description)->>'external')::bool
                 AS {qi(ptr_col_name(schema, NameSpace, 'external'))}
         FROM
-            jsonb_each(
-                edgedb.get_database_metadata(
-                    {ql(defines.EDGEDB_SYSTEM_DB)}
-                ) -> 'NameSpace'
-            ) AS ns
+            information_schema.schemata as s
+            CROSS JOIN LATERAL (
+                select edgedb.obj_metadata(s.schema_name::regnamespace, 'pg_namespace') as DESCRIPTION
+            ) as ns
+            where ns.description ->> 'id' is not null
     '''
 
     annos_link_query = f'''
@@ -4854,7 +4854,7 @@ def _generate_namespace_views(schema: s_schema.Schema) -> List[dbops.View]:
         FROM
             jsonb_each(
                 edgedb.get_database_metadata(
-                    {ql(defines.EDGEDB_SYSTEM_DB)}
+                    current_database()
                 ) -> 'NameSpace'
             ) AS ns
             CROSS JOIN LATERAL
@@ -4874,7 +4874,7 @@ def _generate_namespace_views(schema: s_schema.Schema) -> List[dbops.View]:
         FROM
             jsonb_each(
                 edgedb.get_database_metadata(
-                    {ql(defines.EDGEDB_SYSTEM_DB)}
+                    current_database()
                 ) -> 'NameSpace'
             ) AS ns
             CROSS JOIN LATERAL
diff --git a/edb/schema/namespace.py b/edb/schema/namespace.py
index 1b1106c9e0c..781f0dbf9f4 100644
--- a/edb/schema/namespace.py
+++ b/edb/schema/namespace.py
@@ -28,7 +28,6 @@
 from . import delta as sd
 from . import objects as so
 from . import schema as s_schema
-from . import database as s_database
 
 
 class NameSpace(
@@ -37,7 +36,7 @@ class NameSpace(
     qlkind=qltypes.SchemaObjectClass.NAMESPACE,
     data_safe=False,
 ):
-    db = so.SchemaField(s_database.Database)
+    pass
 
 
 class NameSpaceCommandContext(sd.ObjectCommandContext[NameSpace]):
     pass
@@ -74,14 +73,6 @@ def validate_create(
         super().validate_create(schema, context)
         self._validate_name(schema, context)
 
-    def _create_begin(
-        self,
-        schema: s_schema.Schema,
-        context: sd.CommandContext,
-    ) -> s_schema.Schema:
-        schema = super()._create_begin(schema, context)
-        return schema
-
 
 class DeleteNameSpace(NameSpaceCommand, sd.DeleteExternalObject[NameSpace]):
     astnode = qlast.DropNameSpace
diff --git a/edb/server/compiler/compiler.py b/edb/server/compiler/compiler.py
index cad93ee7338..167209a4feb 100644
--- a/edb/server/compiler/compiler.py
+++ b/edb/server/compiler/compiler.py
@@ -56,6 +56,7 @@
 from edb.ir import ast as irast
 
 from edb.schema import database as s_db
+from edb.schema import namespace as s_ns
 from edb.schema import extensions as s_ext
 from edb.schema import roles as s_roles
 from edb.schema import ddl as s_ddl
@@ -409,7 +410,12 @@ def _process_delta(self, ctx: CompileContext, delta):
             for c in pgdelta.get_subcommands()
         )
 
-        if db_cmd:
+        ns_cmd = any(
+            isinstance(c, s_ns.NameSpaceCommand)
+            for c in pgdelta.get_subcommands()
+        )
+
+        if db_cmd or ns_cmd:
             block = pg_dbops.SQLBlock()
             new_be_types = new_types = frozenset()
         else:
@@ -433,6 +439,7 @@ def may_has_backend_id(_id):
         # schema persistence asynchronizable
         schema_peristence_async = (
             not db_cmd
+            and not ns_cmd
             and not new_be_types
             and not any(
                 isinstance(c, (s_ext.ExtensionCommand,
diff --git a/edb/server/compiler/errormech.py b/edb/server/compiler/errormech.py
index 6291a0bac88..cd770bc2bd6 100644
--- a/edb/server/compiler/errormech.py
+++ b/edb/server/compiler/errormech.py
@@ -83,6 +83,7 @@ class ErrorDetails(NamedTuple):
     pgerrors.ERROR_INVALID_CATALOG_NAME: errors.UnknownDatabaseError,
     pgerrors.ERROR_OBJECT_IN_USE: errors.ExecutionError,
     pgerrors.ERROR_DUPLICATE_DATABASE: errors.DuplicateDatabaseDefinitionError,
+    pgerrors.ERROR_DUPLICATE_SCHEMA: errors.DuplicateNameSpaceDefinitionError,
     pgerrors.ERROR_IDLE_IN_TRANSACTION_TIMEOUT:
         errors.IdleTransactionTimeoutError,
     pgerrors.ERROR_QUERY_CANCELLED: errors.QueryTimeoutError,
diff --git a/edb/server/pgcon/errors.py b/edb/server/pgcon/errors.py
index c1dadd885c0..240c3f6bead 100644
--- a/edb/server/pgcon/errors.py
+++ b/edb/server/pgcon/errors.py
@@ -60,6 +60,7 @@
 ERROR_WRONG_OBJECT_TYPE = '42809'
 ERROR_INSUFFICIENT_PRIVILEGE = '42501'
 ERROR_DUPLICATE_DATABASE = '42P04'
+ERROR_DUPLICATE_SCHEMA = '42P06'
 
 ERROR_PROGRAM_LIMIT_EXCEEDED = '54000'
 
diff --git a/edb/server/protocol/binary.pyx b/edb/server/protocol/binary.pyx
index 00a5e52798d..ac5ed2c7fc1 100644
--- a/edb/server/protocol/binary.pyx
+++ b/edb/server/protocol/binary.pyx
@@ -1584,6 +1584,9 @@ cdef class EdgeConnection(frontend.FrontendConnection):
             tenant_id = self.server.get_tenant_id()
             message = static_exc.args[0].replace(f'{tenant_id}_', '')
             exc = type(static_exc)(message)
+        elif isinstance(static_exc, errors.DuplicateNameSpaceDefinitionError):
+            message = static_exc.args[0].replace('schema', 'namespace').replace('_edgedbext', '')
+            exc = type(static_exc)(message)
         else:
             exc = static_exc
 
diff --git a/edb/server/protocol/execute.pyx b/edb/server/protocol/execute.pyx
index e21c2185106..a806eaabf56 100644
--- a/edb/server/protocol/execute.pyx
+++ b/edb/server/protocol/execute.pyx
@@ -82,10 +82,10 @@ async def execute(
                 )
             if query_unit.drop_db:
                 await server._on_before_drop_db(query_unit.drop_db, dbv.dbname)
-            if query_unit.system_config:
-                await execute_system_config(be_conn, dbv, query_unit)
             if query_unit.create_ns:
                 await server.create_namespace(be_conn, query_unit.create_ns)
+            if query_unit.system_config:
+                await execute_system_config(be_conn, dbv, query_unit)
         else:
             config_ops = query_unit.config_ops
 

From 23731261b1d326ee370658733b008fba803240ac Mon Sep 17 00:00:00 2001
From: JMPRIEST
Date: Sun, 30 Apr 2023 20:52:52 +0800
Subject: [PATCH 04/20] :art: improve the namespace-related error messages
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 edb/errors/__init__.py           | 4 ++++
 edb/pgsql/dbops/namespace.py     | 4 ++--
 edb/server/compiler/errormech.py | 1 +
 edb/server/pgcon/errors.py       | 1 +
 edb/server/protocol/binary.pyx   | 4 +++-
 5 files changed, 11 insertions(+), 3 deletions(-)

diff --git a/edb/errors/__init__.py b/edb/errors/__init__.py
index ee646805cce..4fcf1d9ec54 100644
--- a/edb/errors/__init__.py
+++ b/edb/errors/__init__.py
@@ -213,6 +213,10 @@ class UnknownParameterError(InvalidReferenceError):
     _code = 0x_04_03_00_06
 
 
+class UnknownSchemaError(InvalidReferenceError):
+    _code = 0x_04_03_00_07
+
+
 class SchemaError(QueryError):
     _code = 0x_04_04_00_00
 
diff --git a/edb/pgsql/dbops/namespace.py b/edb/pgsql/dbops/namespace.py
index fac1286c7fa..2b1d87ff7c3 100644
--- a/edb/pgsql/dbops/namespace.py
+++ b/edb/pgsql/dbops/namespace.py
@@ -62,7 +62,7 @@ def code(self, block: base.PLBlock) -> str:
         schemas = ",".join(
             [
                 qi(f"{self.name}_{schema}")
-                for schema in ['edgedb', 'edgedbss', 'edgedbpub', 'edgedbstd', 'edgedbinstdata', 'edgedbext']
+                for schema in ['edgedbext', 'edgedb', 'edgedbss', 'edgedbpub', 'edgedbstd', 'edgedbinstdata', ]
             ]
         )
-        return f'DROP SCHEMA IF EXISTS {schemas} CASCADE;'
+        return f'DROP SCHEMA {schemas} CASCADE;'
diff --git a/edb/server/compiler/errormech.py b/edb/server/compiler/errormech.py
index cd770bc2bd6..78aea17faeb 100644
--- a/edb/server/compiler/errormech.py
+++ b/edb/server/compiler/errormech.py
@@ -81,6 +81,7 @@ class ErrorDetails(NamedTuple):
     pgerrors.ERROR_SERIALIZATION_FAILURE: errors.TransactionSerializationError,
     pgerrors.ERROR_DEADLOCK_DETECTED: errors.TransactionDeadlockError,
     pgerrors.ERROR_INVALID_CATALOG_NAME: errors.UnknownDatabaseError,
+    pgerrors.ERROR_INVALID_SCHEMA_NAME: errors.UnknownSchemaError,
     pgerrors.ERROR_OBJECT_IN_USE: errors.ExecutionError,
     pgerrors.ERROR_DUPLICATE_DATABASE: errors.DuplicateDatabaseDefinitionError,
     pgerrors.ERROR_DUPLICATE_SCHEMA: errors.DuplicateNameSpaceDefinitionError,
diff --git a/edb/server/pgcon/errors.py b/edb/server/pgcon/errors.py
index 240c3f6bead..c3922b826f3 100644
--- a/edb/server/pgcon/errors.py
+++ b/edb/server/pgcon/errors.py
@@ -53,6 +53,7 @@
 ERROR_INVALID_PASSWORD = '28P01'
 
 ERROR_INVALID_CATALOG_NAME = '3D000'
+ERROR_INVALID_SCHEMA_NAME = '3F000'
 
 ERROR_SERIALIZATION_FAILURE = '40001'
 ERROR_DEADLOCK_DETECTED = '40P01'
diff --git a/edb/server/protocol/binary.pyx b/edb/server/protocol/binary.pyx
index ac5ed2c7fc1..1c6536e6746 100644
--- a/edb/server/protocol/binary.pyx
+++ b/edb/server/protocol/binary.pyx
@@ -1584,7 +1584,9 @@ cdef class EdgeConnection(frontend.FrontendConnection):
             tenant_id = self.server.get_tenant_id()
             message = static_exc.args[0].replace(f'{tenant_id}_', '')
             exc = type(static_exc)(message)
-        elif isinstance(static_exc, errors.DuplicateNameSpaceDefinitionError):
+        elif isinstance(static_exc,
+                        (errors.DuplicateNameSpaceDefinitionError, errors.UnknownSchemaError)
+        ):
             message = static_exc.args[0].replace('schema', 'namespace').replace('_edgedbext', '')
             exc = type(static_exc)(message)
         else:

From 0b3ba35ee522c2dc1ec0f00e1b614bee7e7aa208 Mon Sep 17 00:00:00 2001
From: JMPRIEST
Date: Thu, 11 May 2023 20:01:41 +0800
Subject: [PATCH 05/20] :sparkles: dump & restore now support external tables
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 edb/pgsql/delta.py              |  5 +++
 edb/schema/ddl.py               |  1 +
 edb/schema/delta.py             |  2 ++
 edb/server/compiler/compiler.py | 63 ++++++++++++++++++++++-------
 edb/server/protocol/binary.pyx  | 44 ++++++++++++++++++++-
 edb/server/protocol/consts.pxi  |  3 ++
 edb/server/protocol/execute.pyx |  2 +-
 7 files changed, 105 insertions(+), 15 deletions(-)

diff --git a/edb/pgsql/delta.py b/edb/pgsql/delta.py
index 0f612602311..f34d1be3455 100644
--- a/edb/pgsql/delta.py
+++ b/edb/pgsql/delta.py
@@ -6761,6 +6761,11 @@ def collect_external_objects(
                     f"Failed to find view definition for external object {key} ")
             view_def = context.external_view[key]
 
+            if context.restoring_external:
+                self.external_views.append(dbops.View(query=view_def, name=('edgedbpub', str(obj.id))))
+                self.external_views.append(dbops.View(query=view_def, name=('edgedbpub', str(obj.id) + '_t')))
+                return
+
             columns = []
             join_link_table = None
             source_identity = None
diff --git a/edb/schema/ddl.py b/edb/schema/ddl.py
index 5edb0f7a9b0..8a4d2e4ecdc 100644
--- a/edb/schema/ddl.py
+++ b/edb/schema/ddl.py
@@ -256,6 +256,7 @@ def _filter(schema: s_schema.Schema, obj: so.Object) -> bool:
         s_mod.Module,
         s_func.Parameter,
         s_pseudo.PseudoType,
+        s_migr.Migration
     )
 
     schemaclasses = [
diff --git a/edb/schema/delta.py b/edb/schema/delta.py
index 8efd343d28a..1c7f035210f 100644
--- a/edb/schema/delta.py
+++ b/edb/schema/delta.py
@@ -1218,6 +1218,7 @@ def __init__(
         module: Optional[str] = None,
         module_is_implicit: Optional[bool] = False,
         external_view: Optional[Mapping] = None,
+        restoring_external: Optional[bool] = False,
     ) -> None:
         self.stack: List[CommandContextToken[Command]] = []
         self._cache: Dict[Hashable, Any] = {}
@@ -1247,6 +1248,7 @@ def __init__(
         self.module = module
         self.module_is_implicit = module_is_implicit
         self.external_view = external_view or immutables.Map()
+        self.restoring_external = restoring_external
         self.external_objs = set()
 
     @property
diff --git a/edb/server/compiler/compiler.py b/edb/server/compiler/compiler.py
index 4f749aa2dbe..18774ff2502 100644
--- a/edb/server/compiler/compiler.py
+++ b/edb/server/compiler/compiler.py
@@ -42,7 +42,6 @@
 from edb import edgeql
 from edb.common import debug
 from edb.common import verutils
-from edb.common import util
 from edb.common import uuidgen
 from edb.common import ast
 
@@ -138,6 +137,8 @@ class CompileContext:
     in_tx: Optional[bool] = False
     # External view definition
     external_view: Optional[Mapping] = immutables.Map()
+    # Whether we are restoring external views
+    restoring_external: Optional[bool] = False
 
 
 DEFAULT_MODULE_ALIASES_MAP = immutables.Map(
@@ -382,6 +383,7 @@ def _new_delta_context(self, ctx: CompileContext):
             self.get_config_val(ctx, 'allow_dml_in_functions'))
         context.module = ctx.module
         context.external_view = ctx.external_view
+        context.restoring_external = ctx.restoring_external
         return context
 
     def _process_delta(self, ctx: CompileContext, delta):
@@ -2423,6 +2425,7 @@ def compile(
         json_parameters: bool = False,
         module: Optional[str] = None,
        external_view: Optional[Mapping] = None,
+        restoring_external: Optional[bool] = False,
     ) -> Tuple[dbstate.QueryUnitGroup,
                Optional[dbstate.CompilerConnectionState]]:
 
@@ -2467,7 +2470,8 @@ def compile(
             source=source,
             protocol_version=protocol_version,
             module=module,
-            external_view=external_view
+            external_view=external_view,
+            restoring_external=restoring_external
         )
 
         unit_group = self._compile(ctx=ctx, source=source)
@@ -2499,6 +2503,7 @@ def compile_in_tx(
         expect_rollback: bool = False,
         module: Optional[str] = None,
         external_view: Optional[Mapping] = None,
+        restoring_external: Optional[bool] = False,
     ) -> Tuple[dbstate.QueryUnitGroup, dbstate.CompilerConnectionState]:
         if (
             expect_rollback and
@@ -2528,7 +2533,8 @@ def compile_in_tx(
             expect_rollback=expect_rollback,
             module=module,
             in_tx=True,
-            external_view=external_view
+            external_view=external_view,
+            restoring_external=restoring_external
         )
         return self._compile(ctx=ctx, source=source), ctx.state
 
@@ -2549,7 +2555,8 @@ def describe_database_dump(
         config_ddl = config.to_edgeql(config.get_settings(), database_config)
 
         schema_ddl = s_ddl.ddl_text_from_schema(
-            schema, include_migrations=True)
+            schema, include_migrations=True
+        )
 
         all_objects = schema.get_objects(
             exclude_stdlib=True,
@@ -2575,6 +2582,7 @@ def describe_database_dump(
         objtypes = schema.get_objects(
             type=s_objtypes.ObjectType,
             exclude_stdlib=True,
+            extra_filters=[lambda s, o: not o.get_external(s)]
         )
 
         descriptors = []
@@ -2594,11 +2602,25 @@ def describe_database_dump(
                 f'SELECT edgedb._dump_sequences(ARRAY[{seq_ids}]::uuid[])'
             )
 
+        external_ids = []
+        for obj in schema.get_objects(
+            extra_filters=[lambda s, o: o.get_external(s) and (pg_delta.has_table(o, s))]
+        ):
+            if isinstance(obj, s_links.Link):
+                external_ids.append(
+                    ((obj.get_source_type(schema).get_displayname(schema), obj.get_displayname(schema)), str(obj.id))
+                )
+            else:
+                external_ids.append(
+                    ((obj.get_displayname(schema)), str(obj.id))
+                )
+
         return DumpDescriptor(
             schema_ddl=config_ddl + '\n' + schema_ddl,
             schema_dynamic_ddl=tuple(dynamic_ddl),
             schema_ids=ids,
             blocks=descriptors,
+            external_ids=external_ids
         )
 
     def infer_expr(
@@ -2791,6 +2813,7 @@ def describe_database_restore(
         schema_ids: List[Tuple[str, str, bytes]],
         blocks: List[Tuple[bytes, bytes]],  # type_id, typespec
        protocol_version: Tuple[int, int],
+        external_view: Dict[str, str]
     ) -> RestoreDescriptor:
         schema_object_ids = {
             (
@@ -2832,15 +2855,28 @@ def describe_database_restore(
             cached_reflection=EMPTY_MAP,
         )
 
-        ctx = CompileContext(
-            state=state,
-            output_format=enums.OutputFormat.BINARY,
-            expected_cardinality_one=False,
-            compat_ver=dump_server_ver,
-            schema_object_ids=schema_object_ids,
-            log_ddl_as_migrations=False,
-            protocol_version=protocol_version,
-        )
+        if external_view:
+            ctx = CompileContext(
+                state=state,
+                output_format=enums.OutputFormat.BINARY,
+                expected_cardinality_one=False,
+                compat_ver=dump_server_ver,
+                schema_object_ids=schema_object_ids,
+                log_ddl_as_migrations=False,
+                protocol_version=protocol_version,
+                external_view=external_view,
+                restoring_external=True
+            )
+        else:
+            ctx = CompileContext(
+                state=state,
+                output_format=enums.OutputFormat.BINARY,
+                expected_cardinality_one=False,
+                compat_ver=dump_server_ver,
+                schema_object_ids=schema_object_ids,
+                log_ddl_as_migrations=False,
+                protocol_version=protocol_version,
+            )
 
         ctx.state.start_tx()
 
@@ -3088,6 +3124,7 @@ class DumpDescriptor(NamedTuple):
     schema_dynamic_ddl: Tuple[str]
     schema_ids: List[Tuple[str, str, bytes]]
     blocks: Sequence[DumpBlockDescriptor]
+    external_ids: List[Tuple]
 
 
 class DumpBlockDescriptor(NamedTuple):
diff --git a/edb/server/protocol/binary.pyx b/edb/server/protocol/binary.pyx
index 0d7400d3904..7a03b569973 100644
--- a/edb/server/protocol/binary.pyx
+++ b/edb/server/protocol/binary.pyx
@@ -1845,7 +1845,7 @@ cdef class EdgeConnection(frontend.FrontendConnection):
                 db_config = await server.introspect_db_config(pgcon)
                 dump_protocol = self.max_protocol
 
-                schema_ddl, schema_dynamic_ddl, schema_ids, blocks = (
+                schema_ddl, schema_dynamic_ddl, schema_ids, blocks, external_ids = (
                     await compiler_pool.describe_database_dump(
                         user_schema,
                         global_schema,
@@ -1872,6 +1872,20 @@ cdef class EdgeConnection(frontend.FrontendConnection):
             msg_buf.write_int16(dump_protocol[0])
             msg_buf.write_int16(dump_protocol[1])
 
+            # adding external ddl & external ids
+            external_views = await self.external_views(external_ids, pgcon)
+            msg_buf.write_int32(len(external_views))
+            for name, view_sql in external_views:
+                if isinstance(name, tuple):
+                    msg_buf.write_int16(DUMP_EXTERNAL_KEY_LINK)
+                    msg_buf.write_len_prefixed_utf8(name[0])
+                    msg_buf.write_len_prefixed_utf8(name[1])
+                else:
+                    msg_buf.write_int16(DUMP_EXTERNAL_KEY_OBJ)
+                    msg_buf.write_len_prefixed_utf8(name)
+                msg_buf.write_len_prefixed_utf8(view_sql)
+
             msg_buf.write_len_prefixed_utf8(schema_ddl)
 
             msg_buf.write_int32(len(schema_ids))
@@ -1953,6 +1967,15 @@ cdef class EdgeConnection(frontend.FrontendConnection):
         self.write(msg_buf.end_message())
         self.flush()
 
+    async def external_views(self, external_ids: List[Tuple[str, str]], pgcon):
+        views = []
+        for ext_name, ext_id in external_ids:
+            view = f"SELECT view_definition FROM information_schema.views WHERE table_name = '{ext_id}';"
+            view_sql = await pgcon.sql_fetch_val(view.encode('utf-8'))
+            if view_sql is not None:
+                views.append((ext_name, view_sql.decode('utf-8')))
+        return views
+
     async def _execute_utility_stmt(self, eql: str, pgcon):
         cdef dbview.DatabaseConnectionView _dbview
 
@@ -2022,6 +2045,24 @@ cdef class EdgeConnection(frontend.FrontendConnection):
             raise errors.ProtocolError(
                 f'unsupported dump version {proto_major}.{proto_minor}')
 
+        # getting external ddl & external ids
+        external_view_num = self.buffer.read_int32()
+        logger.info(external_view_num)
+        external_views = []
+        for _ in range(external_view_num):
+            key_flag = self.buffer.read_int16()
+            if key_flag == DUMP_EXTERNAL_KEY_LINK:
+                obj_name = self.buffer.read_len_prefixed_utf8()
+                link_name = self.buffer.read_len_prefixed_utf8()
+                sql = self.buffer.read_len_prefixed_utf8()
+                logger.info(obj_name, link_name, sql)
+                external_views.append(((obj_name, link_name), sql))
+            else:
+                name = self.buffer.read_len_prefixed_utf8()
+                sql = self.buffer.read_len_prefixed_utf8()
+                logger.info(name, sql)
+                external_views.append((name, sql))
+
         schema_ddl = self.buffer.read_len_prefixed_bytes()
 
         ids_num = self.buffer.read_int32()
@@ -2076,6 +2117,7 @@ cdef class EdgeConnection(frontend.FrontendConnection):
                 schema_ids,
                 blocks,
                 proto,
+                dict(external_views)
             )
 
         for query_unit in schema_sql_units:
diff --git a/edb/server/protocol/consts.pxi b/edb/server/protocol/consts.pxi
index e74127c523d..c9a5ddb8f96 100644
--- a/edb/server/protocol/consts.pxi
+++ b/edb/server/protocol/consts.pxi
@@ -30,3 +30,6 @@ DEF DUMP_HEADER_BLOCKS_INFO = 104
 DEF DUMP_HEADER_BLOCK_ID = 110
 DEF DUMP_HEADER_BLOCK_NUM = 111
 DEF DUMP_HEADER_BLOCK_DATA = 112
+
+DEF DUMP_EXTERNAL_KEY_OBJ = b'O'
+DEF DUMP_EXTERNAL_KEY_LINK = b'L'
diff --git a/edb/server/protocol/execute.pyx b/edb/server/protocol/execute.pyx
index 810aad07139..688f949119c 100644
--- a/edb/server/protocol/execute.pyx
+++ b/edb/server/protocol/execute.pyx
@@ -321,7 +321,7 @@ async def execute_script(
                 side_effects & dbview.SideEffects.SchemaChanges
                 and group_mutation is not None
             ):
-                dbv.save_schema_mutaion(group_mutation, gmut_unpickled)
+                dbv.save_schema_mutaion(gmut_unpickled, group_mutation)
 
             state = dbv.serialize_state()
             if state is not orig_state:

From 044a0c8de3f8b1b3e7946da8ef7a72e2fd29233d Mon Sep 17 00:00:00 2001
From: JMPRIEST
Date: Fri, 12 May 2023 13:25:56 +0800
Subject: [PATCH 06/20] :sparkles: binary v0 gains dump & restore support for
 external tables; :bug: stay compatible with the existing dump format
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 edb/server/protocol/binary.pyx                | 48 ++++++++++---------
 edb/server/protocol/binary_v0.pyx             | 37 ++++++++++++--
 edb/server/protocol/consts.pxi                |  1 +
 ...tup.esdl => http_create_type_setup.edgeql} |  0
 tests/test_http_create_type.py                |  2 +-
 5 files changed, 61 insertions(+), 27 deletions(-)
 rename tests/schemas/{http_create_type_setup.esdl => http_create_type_setup.edgeql} (100%)

diff --git a/edb/server/protocol/binary.pyx b/edb/server/protocol/binary.pyx
index 7a03b569973..b259685bb65 100644
--- a/edb/server/protocol/binary.pyx
+++ b/edb/server/protocol/binary.pyx
@@ -1862,7 +1862,7 @@ cdef class EdgeConnection(frontend.FrontendConnection):
 
             msg_buf = WriteBuffer.new_message(b'@')
 
-            msg_buf.write_int16(3)  # number of headers
+            msg_buf.write_int16(4)  # number of headers
             msg_buf.write_int16(DUMP_HEADER_BLOCK_TYPE)
             msg_buf.write_len_prefixed_bytes(DUMP_HEADER_BLOCK_TYPE_INFO)
             msg_buf.write_int16(DUMP_HEADER_SERVER_VER)
@@ -1870,10 +1870,8 @@ cdef class EdgeConnection(frontend.FrontendConnection):
             msg_buf.write_int16(DUMP_HEADER_SERVER_TIME)
             msg_buf.write_len_prefixed_utf8(str(int(time.time())))
 
-            msg_buf.write_int16(dump_protocol[0])
-            msg_buf.write_int16(dump_protocol[1])
-
             # adding external ddl & external ids
+            msg_buf.write_int16(DUMP_EXTERNAL_VIEW)
             external_views = await self.external_views(external_ids, pgcon)
             msg_buf.write_int32(len(external_views))
             for name, view_sql in external_views:
@@ -1886,6 +1884,9 @@ cdef class EdgeConnection(frontend.FrontendConnection):
                     msg_buf.write_len_prefixed_utf8(name)
                 msg_buf.write_len_prefixed_utf8(view_sql)
 
+            msg_buf.write_int16(dump_protocol[0])
+            msg_buf.write_int16(dump_protocol[1])
+
             msg_buf.write_len_prefixed_utf8(schema_ddl)
 
             msg_buf.write_int32(len(schema_ids))
@@ -2032,11 +2033,30 @@ cdef class EdgeConnection(frontend.FrontendConnection):
         dump_server_ver_str = None
         headers_num = self.buffer.read_int16()
+        external_views = []
         for _ in range(headers_num):
             hdrname = self.buffer.read_int16()
-            hdrval = self.buffer.read_len_prefixed_bytes()
+            if hdrname != DUMP_EXTERNAL_VIEW:
+                hdrval = self.buffer.read_len_prefixed_bytes()
             if hdrname == DUMP_HEADER_SERVER_VER:
                 dump_server_ver_str = hdrval.decode('utf-8')
+            # getting external ddl & external ids
+            if hdrname == DUMP_EXTERNAL_VIEW:
+                external_view_num = self.buffer.read_int32()
+                logger.info(external_view_num)
+                for _ in range(external_view_num):
+                    key_flag = self.buffer.read_int16()
+                    if key_flag == DUMP_EXTERNAL_KEY_LINK:
+                        obj_name = self.buffer.read_len_prefixed_utf8()
+                        link_name = self.buffer.read_len_prefixed_utf8()
+                        sql = self.buffer.read_len_prefixed_utf8()
+                        logger.info((obj_name, link_name, sql))
+                        external_views.append(((obj_name, link_name), sql))
+                    else:
+                        name = self.buffer.read_len_prefixed_utf8()
+                        sql = self.buffer.read_len_prefixed_utf8()
+                        logger.info((name, sql))
+                        external_views.append((name, sql))
 
         proto_major = self.buffer.read_int16()
         proto_minor = self.buffer.read_int16()
@@ -2045,24 +2065,6 @@ cdef class EdgeConnection(frontend.FrontendConnection):
             raise errors.ProtocolError(
                 f'unsupported dump version {proto_major}.{proto_minor}')
 
-        # getting external ddl & external ids
-        external_view_num = self.buffer.read_int32()
-        logger.info(external_view_num)
-        external_views = []
-        for _ in range(external_view_num):
-            key_flag = self.buffer.read_int16()
-            if key_flag == DUMP_EXTERNAL_KEY_LINK:
-                obj_name = self.buffer.read_len_prefixed_utf8()
-                link_name = self.buffer.read_len_prefixed_utf8()
-                sql = self.buffer.read_len_prefixed_utf8()
-                logger.info(obj_name, link_name, sql)
-                external_views.append(((obj_name, link_name), sql))
-            else:
-                name = self.buffer.read_len_prefixed_utf8()
-                sql = self.buffer.read_len_prefixed_utf8()
-                logger.info(name, sql)
-                external_views.append((name, sql))
-
         schema_ddl = self.buffer.read_len_prefixed_bytes()
 
         ids_num = self.buffer.read_int32()
diff --git a/edb/server/protocol/binary_v0.pyx b/edb/server/protocol/binary_v0.pyx
index 8e36f9e2dfe..43f6c0d1d88 100644
--- a/edb/server/protocol/binary_v0.pyx
+++ b/edb/server/protocol/binary_v0.pyx
@@ -317,7 +317,7 @@ cdef class EdgeConnectionBackwardsCompatible(EdgeConnection):
                 db_config = await server.introspect_db_config(pgcon)
                 dump_protocol = self.max_protocol
 
-                schema_ddl, schema_dynamic_ddl, schema_ids, blocks = (
+                schema_ddl, schema_dynamic_ddl, schema_ids, blocks, external_ids = (
                     await compiler_pool.describe_database_dump(
                         user_schema,
                         global_schema,
@@ -334,7 +334,7 @@ cdef class EdgeConnectionBackwardsCompatible(EdgeConnection):
 
             msg_buf = WriteBuffer.new_message(b'@')
 
-            msg_buf.write_int16(3)  # number of headers
+            msg_buf.write_int16(4)  # number of headers
             msg_buf.write_int16(DUMP_HEADER_BLOCK_TYPE)
             msg_buf.write_len_prefixed_bytes(DUMP_HEADER_BLOCK_TYPE_INFO)
             msg_buf.write_int16(DUMP_HEADER_SERVER_VER)
@@ -342,6 +342,20 @@ cdef class EdgeConnectionBackwardsCompatible(EdgeConnection):
             msg_buf.write_int16(DUMP_HEADER_SERVER_TIME)
             msg_buf.write_len_prefixed_utf8(str(int(time.time())))
 
+            # adding external ddl & external ids
+            msg_buf.write_int16(DUMP_EXTERNAL_VIEW)
+            external_views = await self.external_views(external_ids, pgcon)
+            msg_buf.write_int32(len(external_views))
+            for name, view_sql in external_views:
+                if isinstance(name, tuple):
+                    msg_buf.write_int16(DUMP_EXTERNAL_KEY_LINK)
+                    msg_buf.write_len_prefixed_utf8(name[0])
+                    msg_buf.write_len_prefixed_utf8(name[1])
+                else:
+                    msg_buf.write_int16(DUMP_EXTERNAL_KEY_OBJ)
+                    msg_buf.write_len_prefixed_utf8(name)
+                msg_buf.write_len_prefixed_utf8(view_sql)
+
             msg_buf.write_int16(dump_protocol[0])
             msg_buf.write_int16(dump_protocol[1])
             msg_buf.write_len_prefixed_utf8(schema_ddl)
@@ -447,11 +461,27 @@ cdef class EdgeConnectionBackwardsCompatible(EdgeConnection):
         dump_server_ver_str = None
         headers_num = self.buffer.read_int16()
+        external_views = []
         for _ in range(headers_num):
             hdrname = self.buffer.read_int16()
-            hdrval = self.buffer.read_len_prefixed_bytes()
+            if hdrname != DUMP_EXTERNAL_VIEW:
+                hdrval = self.buffer.read_len_prefixed_bytes()
             if hdrname == DUMP_HEADER_SERVER_VER:
                 dump_server_ver_str = hdrval.decode('utf-8')
+            # getting external ddl & external ids
+            if hdrname == DUMP_EXTERNAL_VIEW:
+                external_view_num = self.buffer.read_int32()
+                for _ in range(external_view_num):
+                    key_flag = self.buffer.read_int16()
+
if key_flag == DUMP_EXTERNAL_KEY_LINK: + obj_name = self.buffer.read_len_prefixed_utf8() + link_name = self.buffer.read_len_prefixed_utf8() + sql = self.buffer.read_len_prefixed_utf8() + external_views.append(((obj_name, link_name), sql)) + else: + name = self.buffer.read_len_prefixed_utf8() + sql = self.buffer.read_len_prefixed_utf8() + external_views.append((name, sql)) proto_major = self.buffer.read_int16() proto_minor = self.buffer.read_int16() @@ -514,6 +544,7 @@ cdef class EdgeConnectionBackwardsCompatible(EdgeConnection): schema_ids, blocks, proto, + dict(external_views) ) for query_unit in schema_sql_units: diff --git a/edb/server/protocol/consts.pxi b/edb/server/protocol/consts.pxi index c9a5ddb8f96..09cc41e1eee 100644 --- a/edb/server/protocol/consts.pxi +++ b/edb/server/protocol/consts.pxi @@ -26,6 +26,7 @@ DEF DUMP_HEADER_BLOCK_TYPE_DATA = b'D' DEF DUMP_HEADER_SERVER_TIME = 102 DEF DUMP_HEADER_SERVER_VER = 103 DEF DUMP_HEADER_BLOCKS_INFO = 104 +DEF DUMP_EXTERNAL_VIEW = 105 DEF DUMP_HEADER_BLOCK_ID = 110 DEF DUMP_HEADER_BLOCK_NUM = 111 diff --git a/tests/schemas/http_create_type_setup.esdl b/tests/schemas/http_create_type_setup.edgeql similarity index 100% rename from tests/schemas/http_create_type_setup.esdl rename to tests/schemas/http_create_type_setup.edgeql diff --git a/tests/test_http_create_type.py b/tests/test_http_create_type.py index b88be97deef..38340edff54 100644 --- a/tests/test_http_create_type.py +++ b/tests/test_http_create_type.py @@ -63,7 +63,7 @@ class TestHttpCreateType(tb.ExternTestCase): ) SETUP = os.path.join( os.path.dirname(__file__), 'schemas', - 'http_create_type_setup.esdl' + 'http_create_type_setup.edgeql' ) # EdgeQL/HTTP queries cannot run in a transaction From aecdac75b29b385a96bea29081d87c30050c38e8 Mon Sep 17 00:00:00 2001 From: JMPRIEST Date: Fri, 12 May 2023 18:05:35 +0800 Subject: [PATCH 07/20] =?UTF-8?q?:memo:=20=E5=A4=87=E4=BB=BD=E9=AA=8C?= =?UTF-8?q?=E8=AF=81v0=E9=80=BB=E8=BE=91=E7=9A=84script?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- edb/server/compiler_pool/worker.py | 14 ++ scripts/create_type.py | 185 +++++++++++++++++++ scripts/dump_restore_v0.py | 284 +++++++++++++++++++++++++++++ 3 files changed, 483 insertions(+) create mode 100644 scripts/create_type.py create mode 100644 scripts/dump_restore_v0.py diff --git a/edb/server/compiler_pool/worker.py b/edb/server/compiler_pool/worker.py index 8b222743e55..add80ad95a7 100644 --- a/edb/server/compiler_pool/worker.py +++ b/edb/server/compiler_pool/worker.py @@ -305,6 +305,20 @@ def try_compile_rollback( return COMPILER.try_compile_rollback(*compile_args, **compile_kwargs) +def describe_database_dump( + *compile_args: Any, + **compile_kwargs: Any, +): + return COMPILER.describe_database_dump(*compile_args, **compile_kwargs) + + +def describe_database_restore( + *compile_args: Any, + **compile_kwargs: Any, +): + return COMPILER.describe_database_restore(*compile_args, **compile_kwargs) + + def compile_graphql( dbname: str, user_schema: Optional[bytes], diff --git a/scripts/create_type.py b/scripts/create_type.py new file mode 100644 index 00000000000..54d30e94c20 --- /dev/null +++ b/scripts/create_type.py @@ -0,0 +1,185 @@ +from typing import Dict, List, NamedTuple +import requests + + +def as_dict(self=None) -> Dict: + payload = {} + for k, v in self._asdict().items(): + if v is None: + continue + if isinstance(v, list): + payload[k] = [e.as_dict() for e in v] + elif isinstance(v, (PropertyDetail, LinkDetail,)): + payload[k] = 
v.as_dict() + else: + payload[k] = v + return payload + + +class PropertyDetail(NamedTuple): + name: str + type: str = None + alias: str = None + cardinality: str = None + required: bool = None + expr: str = None + exclusive: bool = None + + as_dict = as_dict + + +class LinkDetail(NamedTuple): + name: str + to: str = None + type: str = None + alias: str = None + cardinality: str = None + required: bool = None + from_: str = None + relation: str = None + source: str = None + target: str = None + properties: List[PropertyDetail] = None + + as_dict = as_dict + + +class CreateTypeBody(NamedTuple): + module: str + name: str + relation: str + properties: List[PropertyDetail] = None + links: List[LinkDetail] = None + + as_dict = as_dict + + +def create_facility(link_to, link_card='single', has_link_prop=True): + if has_link_prop or link_card == 'multi': + if link_card == 'single': + relation = "(select distinct on (facid) * from cd.bookings) as bookings" + else: + relation = "(select distinct on (facid, memid) * from cd.bookings) as bookings" + link_detail = [ + LinkDetail( + name="bookedby", + cardinality=link_card, + relation=relation, + type=link_to, + source="facid", + target="memid", + from_="fid", + to="mid", + properties=[ + PropertyDetail( + name="starttime", + type="timestamp", + alias="start_at" + ), + PropertyDetail( + name="slots", + type="int4" + ) + ] if has_link_prop else None, + ), + ] + else: + link_detail = [ + LinkDetail( + name="booked_by", + type=link_to, + from_="fid", + to="mid", + ) + ] + + return CreateTypeBody( + module="default", + name="Facility", + relation="cd.facilities", + properties=[ + PropertyDetail( + name="facid", + type="int4", + alias="fid", + exclusive=True, + ), + PropertyDetail( + name="membercost", + type="numeric", + alias="member_cost" + ), + PropertyDetail( + name="name", + type="str" + ), + PropertyDetail( + name="guestcost", + type="numeric", + alias="guest_cost" + ), + PropertyDetail( + name="discount", + expr=".member_cost / .guest_cost" + ) + ], + links=link_detail + ) + + +def create_member_from_outer(): + return CreateTypeBody( + module="default", + name="Member", + relation="cd.members", + properties=[ + PropertyDetail( + name="memid", + type="int4", + alias="mid", + exclusive=True, + ), + PropertyDetail( + name="surname", + type="str", + ), + PropertyDetail( + name="firstname", + type="str" + ), + PropertyDetail( + name="address", + type="str", + ), + PropertyDetail( + name="zipcode", + type="int4" + ), + PropertyDetail( + name="telephone", + type="str", + alias="phone_number" + ), + PropertyDetail( + name="recommendedby", + type="int4" + ), + PropertyDetail( + name="joindate", + type="timestamp" + ) + ], + ) + + +body = create_member_from_outer() + +r = requests.post('http://127.0.0.1:5656/db/demo_dump/extern/create-type', + json=body.as_dict()) +print(r.text) + +body = create_facility('Member', 'multi') + +r = requests.post('http://127.0.0.1:5656/db/demo_dump/extern/create-type', + json=body.as_dict()) +print(r.text) diff --git a/scripts/dump_restore_v0.py b/scripts/dump_restore_v0.py new file mode 100644 index 00000000000..77e4f666456 --- /dev/null +++ b/scripts/dump_restore_v0.py @@ -0,0 +1,284 @@ +import asyncio +import json +import os + +from asyncpg import connection as pg_connection +import edgedb + + +class AsyncBuffer: + def __init__(self): + self.datas = [] + + async def append(self, value: bytes): + self.datas.append(value) + + async def __aiter__(self): + for value in self.datas: + yield value + + +async def main(dbname, 
valid_query, expect_result, sql_for_external=None): + # edgedb==0.21.0 + print(f"Is v0 protocol: {tuple([int(i) for i in edgedb.__version__.split('.')]) <= (0, 23, 0)}") + client_from = edgedb.create_async_client(host='127.0.0.1', port=5656, database=dbname, + tls_security='insecure') + header = AsyncBuffer() + body = AsyncBuffer() + try: + await client_from.ensure_connected() + + async with client_from._acquire() as conn: + await conn._inner._impl._protocol.dump(header_callback=header.append, block_callback=body.append) + + print("Dump done.") + + if ( + (data := (await client_from.query_json(f"select sys::Database filter .name='restored_{dbname}'"))) + and len(json.loads(data)) > 0 + ): + await client_from.execute(f"drop database restored_{dbname}") + await client_from.execute(f"create database restored_{dbname}") + if sql_for_external is not None: + pg_conn = await pg_connection.connect( + dsn=f"postgresql://postgres:@127.0.0.1:5432?database=V2f147ded60_restored_{dbname}" + ) + await pg_conn.execute(sql_for_external) + await pg_conn.close() + + print("To db prepared.") + client_to = edgedb.create_async_client(host='127.0.0.1', port=5656, database=f'restored_{dbname}', + tls_security='insecure') + try: + await client_to.ensure_connected() + async with client_to._acquire() as conn: + await conn._inner._impl._protocol.restore(header=b"".join(header.datas), data_gen=body) + print("Restore done.") + + data = await client_to.query_json(valid_query) + assert json.loads(data) == expect_result + + print('Data valid successfully.') + + finally: + await client_to.aclose() + finally: + await client_from.aclose() + + +if __name__ == '__main__': + # case no external + asyncio.run(main( + 'cards', + """ + SELECT User { + name, + deck: { + name, + element, + cost, + @count + } ORDER BY @count DESC THEN .name ASC + } ORDER BY .name + """, + [ + { + 'name': 'Alice', + 'deck': [ + { + 'cost': 2, + 'name': 'Bog monster', + '@count': 3, + 'element': 'Water' + }, + { + 'cost': 3, + 'name': 'Giant turtle', + '@count': 3, + 'element': 'Water' + }, + { + 'cost': 5, + 'name': 'Dragon', + '@count': 2, + 'element': 'Fire' + }, + { + 'cost': 1, + 'name': 'Imp', + '@count': 2, + 'element': 'Fire' + }, + ], + }, + { + 'name': 'Bob', + 'deck': [ + { + 'cost': 2, + 'name': 'Bog monster', + '@count': 3, + 'element': 'Water' + }, + { + 'cost': 1, + 'name': 'Dwarf', + '@count': 3, + 'element': 'Earth' + }, + { + 'cost': 3, + 'name': 'Giant turtle', + '@count': 3, + 'element': 'Water' + }, + { + 'cost': 3, + 'name': 'Golem', + '@count': 3, + 'element': 'Earth' + }, + ], + }, + { + 'name': 'Carol', + 'deck': [ + { + 'cost': 1, + 'name': 'Dwarf', + '@count': 4, + 'element': 'Earth' + }, + { + 'cost': 1, + 'name': 'Sprite', + '@count': 4, + 'element': 'Air' + }, + { + 'cost': 2, + 'name': 'Bog monster', + '@count': 3, + 'element': 'Water' + }, + { + 'cost': 2, + 'name': 'Giant eagle', + '@count': 3, + 'element': 'Air' + }, + { + 'cost': 3, + 'name': 'Giant turtle', + '@count': 2, + 'element': 'Water' + }, + { + 'cost': 3, + 'name': 'Golem', + '@count': 2, + 'element': 'Earth' + }, + { + 'cost': 4, + 'name': 'Djinn', + '@count': 1, + 'element': 'Air' + }, + ], + }, + { + 'name': 'Dave', + 'deck': [ + { + 'cost': 1, + 'name': 'Sprite', + '@count': 4, + 'element': 'Air' + }, + { + 'cost': 2, + 'name': 'Bog monster', + '@count': 1, + 'element': 'Water' + }, + { + 'cost': 4, + 'name': 'Djinn', + '@count': 1, + 'element': 'Air' + }, + { + 'cost': 5, + 'name': 'Dragon', + '@count': 1, + 'element': 'Fire' + }, + { + 'cost': 2, + 
'name': 'Giant eagle', + '@count': 1, + 'element': 'Air' + }, + { + 'cost': 3, + 'name': 'Giant turtle', + '@count': 1, + 'element': 'Water' + }, + { + 'cost': 3, + 'name': 'Golem', + '@count': 1, + 'element': 'Earth' + }, + ], + } + ] + )) + PROJECT_PATH = os.path.dirname(os.path.dirname(__file__)) + # case with external + with open(os.path.join(PROJECT_PATH, 'tests', 'schemas', 'http_create_type.sql'), 'rt') as f: + outter_schema = f.read() + asyncio.run(main( + 'demo_dump', + """ + SELECT distinct + Facility { + fid, name, + bookedby: { + mid, + fullname := .surname ++ '.' ++ .firstname + } ORDER BY .mid, + } FILTER .fid=1; + """, + [{ + 'fid': 1, + 'name': 'Tennis Court 2', + 'bookedby': [ + {'mid': 0, 'fullname': 'GUEST.GUEST'}, + {'mid': 1, 'fullname': 'Smith.Darren'}, + {'mid': 2, 'fullname': 'Smith.Tracy'}, + {'mid': 3, 'fullname': 'Rownam.Tim'}, + {'mid': 4, 'fullname': 'Joplette.Janice'}, + {'mid': 5, 'fullname': 'Butters.Gerald'}, + {'mid': 6, 'fullname': 'Tracy.Burton'}, + {'mid': 7, 'fullname': 'Dare.Nancy'}, + {'mid': 8, 'fullname': 'Boothe.Tim'}, + {'mid': 9, 'fullname': 'Stibbons.Ponder'}, + {'mid': 10, 'fullname': 'Owen.Charles'}, + {'mid': 11, 'fullname': 'Jones.David'}, + {'mid': 12, 'fullname': 'Baker.Anne'}, + {'mid': 13, 'fullname': 'Farrell.Jemima'}, + {'mid': 14, 'fullname': 'Smith.Jack'}, + {'mid': 15, 'fullname': 'Bader.Florence'}, + {'mid': 16, 'fullname': 'Baker.Timothy'}, + {'mid': 24, 'fullname': 'Sarwin.Ramnaresh'}, + {'mid': 27, 'fullname': 'Rumney.Henrietta'}, + {'mid': 28, 'fullname': 'Farrell.David'}, + {'mid': 30, 'fullname': 'Purview.Millicent'}, + {'mid': 35, 'fullname': 'Hunt.John'}, + ] + }], + outter_schema + )) From b879622c67c945eadca2d2deb56d8bc3ce050074 Mon Sep 17 00:00:00 2001 From: JMPRIEST Date: Mon, 15 May 2023 10:46:33 +0800 Subject: [PATCH 08/20] =?UTF-8?q?:mute:=20=E5=88=A0=E9=99=A4dump&restore?= =?UTF-8?q?=20debug=E6=9C=9F=E9=97=B4=E7=9A=84=E6=97=A5=E5=BF=97=E6=89=93?= =?UTF-8?q?=E5=8D=B0=20:art:=20fix=20typo?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- edb/server/dbview/dbview.pyx | 4 ++-- edb/server/protocol/binary.pyx | 3 --- edb/server/protocol/execute.pyx | 2 +- 3 files changed, 3 insertions(+), 6 deletions(-) diff --git a/edb/server/dbview/dbview.pyx b/edb/server/dbview/dbview.pyx index 320efcede69..1479ea818db 100644 --- a/edb/server/dbview/dbview.pyx +++ b/edb/server/dbview/dbview.pyx @@ -1085,7 +1085,7 @@ cdef class DatabaseConnectionView: await be_conn.sql_execute(sqls) self._in_tx_sp_sqls.clear() - def save_schema_mutaion(self, mut, mut_bytes): + def save_schema_mutation(self, mut, mut_bytes): self._db._index._server.get_compiler_pool().append_schema_mutation( self.dbname, mut_bytes, @@ -1106,7 +1106,7 @@ cdef class DatabaseConnectionView: and side_effects and (side_effects & SideEffects.SchemaChanges) ): - self.save_schema_mutaion( + self.save_schema_mutation( query_unit.user_schema_mutation_obj, query_unit.user_schema_mutation, ) diff --git a/edb/server/protocol/binary.pyx b/edb/server/protocol/binary.pyx index b259685bb65..184dcf8374c 100644 --- a/edb/server/protocol/binary.pyx +++ b/edb/server/protocol/binary.pyx @@ -2043,19 +2043,16 @@ cdef class EdgeConnection(frontend.FrontendConnection): # getting external ddl & external ids if hdrname == DUMP_EXTERNAL_VIEW: external_view_num = self.buffer.read_int32() - logger.info(external_view_num) for _ in range(external_view_num): key_flag = self.buffer.read_int16() if key_flag == DUMP_EXTERNAL_KEY_LINK: obj_name = 
self.buffer.read_len_prefixed_utf8() link_name = self.buffer.read_len_prefixed_utf8() sql = self.buffer.read_len_prefixed_utf8() - logger.info((obj_name, link_name, sql)) external_views.append(((obj_name, link_name), sql)) else: name = self.buffer.read_len_prefixed_utf8() sql = self.buffer.read_len_prefixed_utf8() - logger.info((name, sql)) external_views.append((name, sql)) proto_major = self.buffer.read_int16() diff --git a/edb/server/protocol/execute.pyx b/edb/server/protocol/execute.pyx index 688f949119c..77eb584174e 100644 --- a/edb/server/protocol/execute.pyx +++ b/edb/server/protocol/execute.pyx @@ -321,7 +321,7 @@ async def execute_script( side_effects & dbview.SideEffects.SchemaChanges and group_mutation is not None ): - dbv.save_schema_mutaion(gmut_unpickled, group_mutation) + dbv.save_schema_mutation(gmut_unpickled, group_mutation) state = dbv.serialize_state() if state is not orig_state: From 6c18fdd521327a06bf437eab5627379f463de60e Mon Sep 17 00:00:00 2001 From: JMPRIEST Date: Mon, 15 May 2023 19:59:27 +0800 Subject: [PATCH 09/20] =?UTF-8?q?:white=5Fcheck=5Fmark:=20=E7=BC=96?= =?UTF-8?q?=E5=86=99=E5=A4=96=E9=83=A8=E8=A1=A8dump&restore=E6=B5=8B?= =?UTF-8?q?=E8=AF=95=E4=BB=A3=E7=A0=81?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- edb/server/compiler/compiler.py | 8 +- edb/server/dbview/dbview.pxd | 2 +- edb/server/dbview/dbview.pyx | 10 +- edb/server/protocol/binary.pyx | 3 +- edb/server/protocol/execute.pyx | 6 +- edb/server/protocol/extern_obj.py | 6 +- edb/server/protocol/protocol.pxd | 1 + edb/server/protocol/protocol.pyx | 3 + edb/testbase/http.py | 1 + edb/testbase/server.py | 8 +- tests/test_http_create_type.py | 1661 +++++++++++++++-------------- 11 files changed, 892 insertions(+), 817 deletions(-) diff --git a/edb/server/compiler/compiler.py b/edb/server/compiler/compiler.py index 18774ff2502..2609f4ddf69 100644 --- a/edb/server/compiler/compiler.py +++ b/edb/server/compiler/compiler.py @@ -139,6 +139,8 @@ class CompileContext: external_view: Optional[Mapping] = immutables.Map() # If in restoring external view restoring_external: Optional[bool] = False + # If is test mode from http + testmode: Optional[bool] = False DEFAULT_MODULE_ALIASES_MAP = immutables.Map( @@ -373,7 +375,7 @@ def get_std_schema(self) -> s_schema.Schema: def _new_delta_context(self, ctx: CompileContext): context = s_delta.CommandContext() - context.testmode = self.get_config_val(ctx, '__internal_testmode') + context.testmode = self.get_config_val(ctx, '__internal_testmode') or ctx.testmode context.stdmode = ctx.bootstrap_mode context.internal_schema_mode = ctx.internal_schema_mode context.schema_object_ids = ctx.schema_object_ids @@ -2426,6 +2428,7 @@ def compile( module: Optional[str] = None, external_view: Optional[Mapping] = None, restoring_external: Optional[bool] = False, + testmode: bool = False ) -> Tuple[dbstate.QueryUnitGroup, Optional[dbstate.CompilerConnectionState]]: @@ -2471,7 +2474,8 @@ def compile( protocol_version=protocol_version, module=module, external_view=external_view, - restoring_external=restoring_external + restoring_external=restoring_external, + testmode=testmode ) unit_group = self._compile(ctx=ctx, source=source) diff --git a/edb/server/dbview/dbview.pxd b/edb/server/dbview/dbview.pxd index 2eddf70ef1d..0ccb3f2bd04 100644 --- a/edb/server/dbview/dbview.pxd +++ b/edb/server/dbview/dbview.pxd @@ -47,7 +47,7 @@ cdef class QueryRequestInfo: cdef public object module cdef public bint read_only cdef public object external_view 
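This hunk adds a per-request `testmode` flag right after `external_view`; the dbview.pyx hunks that follow fold it into QueryRequestInfo's `__hash__` and `__eq__`. The invariant being preserved: every field that can change the compiler's output must also change the compilation cache key, otherwise a test-mode compilation could be served from, or pollute, the normal cache. A minimal pure-Python sketch of that invariant follows; the frozen dataclass and its field list are illustrative stand-ins for the Cython class, not its real definition:

    from dataclasses import dataclass
    from typing import Optional

    @dataclass(frozen=True)  # frozen + eq -> hashable, usable as a cache key
    class RequestKeySketch:
        query: str
        module: Optional[str] = None
        read_only: bool = False
        testmode: bool = False

    # Same query text, different testmode: must occupy different cache slots.
    assert RequestKeySketch("select 1") != RequestKeySketch("select 1", testmode=True)
    assert hash(RequestKeySketch("select 1")) == hash(RequestKeySketch("select 1"))
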
- + cdef public bint testmode cdef int cached_hash diff --git a/edb/server/dbview/dbview.pyx b/edb/server/dbview/dbview.pyx index 1479ea818db..46c1ae54463 100644 --- a/edb/server/dbview/dbview.pyx +++ b/edb/server/dbview/dbview.pyx @@ -85,6 +85,7 @@ cdef class QueryRequestInfo: allow_capabilities: uint64_t = compiler.Capability.ALL, module: str = None, read_only: bint = False, + testmode: bint = False, external_view: object = immutables.Map(), ): self.source = source @@ -99,6 +100,7 @@ cdef class QueryRequestInfo: self.allow_capabilities = allow_capabilities self.module = module self.read_only = read_only + self.testmode = testmode self.external_view = external_view self.cached_hash = hash(( @@ -112,7 +114,8 @@ cdef class QueryRequestInfo: self.inline_typenames, self.inline_objectids, self.module, - self.read_only + self.read_only, + self.testmode )) def __hash__(self): @@ -130,7 +133,8 @@ cdef class QueryRequestInfo: self.inline_typenames == other.inline_typenames and self.inline_objectids == other.inline_objectids and self.module == other.module and - self.read_only == other.read_only + self.read_only == other.read_only and + self.testmode == other.testmode ) @@ -1405,6 +1409,8 @@ cdef class DatabaseConnectionView: query_req.input_format is compiler.InputFormat.JSON, query_req.module, query_req.external_view, + False, + query_req.testmode ) finally: metrics.edgeql_query_compilation_duration.observe( diff --git a/edb/server/protocol/binary.pyx b/edb/server/protocol/binary.pyx index 184dcf8374c..e1bcd6462d2 100644 --- a/edb/server/protocol/binary.pyx +++ b/edb/server/protocol/binary.pyx @@ -1825,6 +1825,7 @@ cdef class EdgeConnection(frontend.FrontendConnection): # # This guarantees that every pg connection and the compiler work # with the same DB state. 
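The comment above states the constraint this code serves: schema introspection and the dump's data blocks must agree on one database state, which the dump path pins with a read-only snapshot transaction on the backend connection (the hunk below merely hoists `introspect_user_schema` ahead of that transaction). For reference, the general snapshot pattern expressed with asyncpg; the DSN, query, and isolation options are placeholder assumptions for illustration, not the server's actual settings:

    import asyncpg

    async def with_dump_snapshot(dsn: str) -> None:
        conn = await asyncpg.connect(dsn)
        try:
            # One read-only snapshot: every read inside it is mutually
            # consistent, like the schema/data reads in the dump path.
            tx = conn.transaction(
                isolation='serializable', readonly=True, deferrable=True)
            await tx.start()
            try:
                await conn.fetchval('SELECT count(*) FROM pg_class')
            finally:
                await tx.rollback()  # read-only, nothing to commit
        finally:
            await conn.close()
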
+ user_schema = await server.introspect_user_schema(dbname, pgcon) await pgcon.sql_execute( b'''START TRANSACTION @@ -1839,8 +1840,6 @@ cdef class EdgeConnection(frontend.FrontendConnection): SET statement_timeout = 0; ''', ) - - user_schema = await server.introspect_user_schema(dbname, pgcon) global_schema = await server.introspect_global_schema(pgcon) db_config = await server.introspect_db_config(pgcon) dump_protocol = self.max_protocol diff --git a/edb/server/protocol/execute.pyx b/edb/server/protocol/execute.pyx index 77eb584174e..af1b8570348 100644 --- a/edb/server/protocol/execute.pyx +++ b/edb/server/protocol/execute.pyx @@ -413,11 +413,12 @@ async def parse_execute( query: str, *, external_view: Mapping = immutables.Map(), + testmode: bool=False ): server = db.server dbv = await server.new_dbview( dbname=db.name, - query_cache=False , + query_cache=False, protocol_version=edbdef.CURRENT_PROTOCOL, ) @@ -427,7 +428,8 @@ async def parse_execute( input_format=compiler.InputFormat.JSON, output_format=compiler.OutputFormat.NONE, allow_capabilities=compiler.Capability.MODIFICATIONS | compiler.Capability.DDL, - external_view=external_view + external_view=external_view, + testmode=testmode ) compiled = await dbv.parse(query_req) diff --git a/edb/server/protocol/extern_obj.py b/edb/server/protocol/extern_obj.py index 8c640cc093a..972a1fb6b81 100644 --- a/edb/server/protocol/extern_obj.py +++ b/edb/server/protocol/extern_obj.py @@ -278,7 +278,6 @@ def to_ddl(self): return f"DROP TYPE {self.qualname}" - async def handle_request( request, response, @@ -336,7 +335,8 @@ def _unknown_path(): await execute.parse_execute( db, req.to_ddl(), - external_view=req.resolve_view() + external_view=req.resolve_view(), + testmode=bool(request.testmode) ) except Exception as ex: if debug.flags.server: @@ -356,5 +356,3 @@ def _unknown_path(): response.body = json.dumps({'error': err_dct}).encode() else: response.body = b'{"data": "ok"}' - - diff --git a/edb/server/protocol/protocol.pxd b/edb/server/protocol/protocol.pxd index 65c98745154..e6e6df61b24 100644 --- a/edb/server/protocol/protocol.pxd +++ b/edb/server/protocol/protocol.pxd @@ -33,6 +33,7 @@ cdef class HttpRequest: public bytes host public bytes authorization public object params + public bytes testmode cdef class HttpResponse: diff --git a/edb/server/protocol/protocol.pyx b/edb/server/protocol/protocol.pyx index 80c3e9e0450..3cdbea28d23 100644 --- a/edb/server/protocol/protocol.pyx +++ b/edb/server/protocol/protocol.pyx @@ -74,6 +74,7 @@ cdef class HttpRequest: self.body = b'' self.authorization = b'' self.content_type = b'' + self.testmode = b'' cdef class HttpResponse: @@ -216,6 +217,8 @@ cdef class HttpProtocol: self.current_request.params = {} param = name[len(b'x-edgedb-'):] self.current_request.params[param] = value + elif name == b'testmode': + self.current_request.testmode = value def on_body(self, body: bytes): self.current_request.body += body diff --git a/edb/testbase/http.py b/edb/testbase/http.py index 0080d2d79c2..0ad930c49b6 100644 --- a/edb/testbase/http.py +++ b/edb/testbase/http.py @@ -368,6 +368,7 @@ def create_type(self, body): req = urllib.request.Request(self.http_addr, method='POST') req.add_header('Content-Type', 'application/json') + req.add_header('testmode', '1') response = urllib.request.urlopen( req, json.dumps(req_data).encode(), context=self.tls_context ) diff --git a/edb/testbase/server.py b/edb/testbase/server.py index 639c7213268..b2fd415e6dc 100644 --- a/edb/testbase/server.py +++ b/edb/testbase/server.py @@ -788,7 
+788,7 @@ async def assertRaisesRegexTx(self, exception, regex, msg=None, **kwargs):
 class PGConnMixin:
 
     @classmethod
-    def pg_conn(cls):
+    def pg_conn(cls, dbname: str = None):
         conn_spec = {}
         addrs, params = pgconnparams.parse_dsn(cls.backend_dsn)
         conn_spec['host'], conn_spec['port'] = addrs[0]
@@ -814,7 +814,7 @@ def pg_conn(cls):
         return pgcon.connect(
             conn_spec,
             pgcommon.get_database_backend_name(
-                cls.get_database_name(),
+                dbname or cls.get_database_name(),
                 tenant_id=instance_params.tenant_id
             ),
             pgparams.BackendRuntimeParams(
@@ -1368,7 +1368,7 @@ async def check_dump_restore_single_db(self, check_method):
             self.run_cli('-d', dbname, 'restore', f.name)
         await check_method(self)
 
-    async def check_dump_restore(self, check_method):
+    async def check_dump_restore(self, check_method, restore_db_prepare=None):
         if not self.has_create_database:
             return await self.check_dump_restore_single_db(check_method)
 
@@ -1380,6 +1380,8 @@ async def check_dump_restore(self, check_method):
         await self.con.execute(f'CREATE DATABASE {q_tgt_dbname}')
 
         try:
+            if restore_db_prepare is not None:
+                await restore_db_prepare()
             self.run_cli('-d', tgt_dbname, 'restore', f.name)
             con2 = await self.connect(database=tgt_dbname)
         except Exception:
diff --git a/tests/test_http_create_type.py b/tests/test_http_create_type.py
index 38340edff54..64648503352 100644
--- a/tests/test_http_create_type.py
+++ b/tests/test_http_create_type.py
@@ -1,7 +1,11 @@
 import os
+import re
 from typing import Dict, List, NamedTuple
+
 import edgedb
-from edb.testbase import http as tb
+
+from edb.testbase import http as http_tb
+from edb.testbase import server as server_tb
 
 
 def as_dict(self=None) -> Dict:
@@ -56,7 +60,718 @@ class CreateTypeBody(NamedTuple):
     as_dict = as_dict
 
 
-class TestHttpCreateType(tb.ExternTestCase):
+RE_CASE = re.compile(
+    'test\_link\_(?P<from_>[a-z]*)\_(?P<to_>[a-z]*)(?:\_(?P<link_type>[a-z]*)\_link)?(?:\_(?P<link_prop>[a-z\_]*))?'
+)
+
+
+class HttpCreateTypeMixin:
+    async def link_outer_outer_single_link(self):
+        # query object without link
+        await self.assert_query_result(
+            r'''
+                SELECT
+                    Facility {
+                        fid, discount, name
+                    } FILTER .fid = 1;
+            ''',
+            [{
+                'fid': 1,
+                'discount': 5 / 25,
+                'name': 'Tennis Court 2'
+            }]
+        )
+        # query object with link
+        await self.assert_query_result(
+            r'''
+                SELECT distinct
+                    Facility {
+                        fid, name,
+                        booked_by: {
+                            mid,
+                            fullname := .surname ++ '.' ++ .firstname
+                        } ORDER BY .mid,
+                    } FILTER .fid=1;
+            ''',
+            [{
+                'fid': 1,
+                'name': 'Tennis Court 2',
+                'booked_by': {'mid': 3, 'fullname': 'Rownam.Tim'}
+            }]
+        )
+        with self.assertRaisesRegex(
+            edgedb.InvalidTypeError,
+            regex="operator '=' cannot be applied to operands of type .*"
+        ):
+            await self.con.execute(
+                r'''
+                    select Facility FILTER .fid = '1';
+                '''
+            )
+        with self.assertRaisesRegex(
+            edgedb.InvalidTypeError,
+            regex="operator '=' cannot be applied to operands of type .*"
+        ):
+            await self.con.execute(
+                r'''
+                    select Facility FILTER .booked_by.mid = '1';
+                '''
+            )
+
+    async def link_outer_outer_single_link_with_prop(self):
+        # query object without link
+        await self.assert_query_result(
+            r'''
+                SELECT
+                    Facility {
+                        fid, discount, name
+                    } FILTER .fid = 1;
+            ''',
+            [{
+                'fid': 1,
+                'discount': 5 / 25,
+                'name': 'Tennis Court 2'
+            }]
+        )
+        # query object with link
+        await self.assert_query_result(
+            r'''
+                SELECT Facility {
+                    fid, name,
+                    bookedby: {
+                        mid,
+                        fullname := .surname ++ '.' 
++ .firstname + } ORDER BY .mid, + } FILTER .fid=1; + ''', + [{ + 'fid': 1, + 'name': 'Tennis Court 2', + 'bookedby': {'mid': 0, 'fullname': 'GUEST.GUEST'} + }] + ) + with self.assertRaisesRegex( + edgedb.InvalidTypeError, + regex="operator '=' cannot be applied to operands of type .*" + ): + await self.con.execute( + r''' + select Facility FILTER .fid = '1'; + ''' + ) + with self.assertRaisesRegex( + edgedb.InvalidTypeError, + regex="operator '=' cannot be applied to operands of type .*" + ): + await self.con.execute( + r''' + select Facility FILTER .bookedby.mid = '1'; + ''' + ) + + async def link_outer_outer_multi_link_with_prop(self): + # query object without link + await self.assert_query_result( + r''' + SELECT + Facility { + fid, discount, name + } FILTER .fid = 1; + ''', + [{ + 'fid': 1, + 'discount': 5 / 25, + 'name': 'Tennis Court 2' + }] + ) + # query object with link + await self.assert_query_result( + r''' + SELECT distinct + Facility { + fid, name, + bookedby: { + mid, + fullname := .surname ++ '.' ++ .firstname + } ORDER BY .mid, + } FILTER .fid=1; + ''', + [{ + 'fid': 1, + 'name': 'Tennis Court 2', + 'bookedby': [ + {'mid': 0, 'fullname': 'GUEST.GUEST'}, + {'mid': 1, 'fullname': 'Smith.Darren'}, + {'mid': 2, 'fullname': 'Smith.Tracy'}, + {'mid': 3, 'fullname': 'Rownam.Tim'}, + {'mid': 4, 'fullname': 'Joplette.Janice'}, + {'mid': 5, 'fullname': 'Butters.Gerald'}, + {'mid': 6, 'fullname': 'Tracy.Burton'}, + {'mid': 7, 'fullname': 'Dare.Nancy'}, + {'mid': 8, 'fullname': 'Boothe.Tim'}, + {'mid': 9, 'fullname': 'Stibbons.Ponder'}, + {'mid': 10, 'fullname': 'Owen.Charles'}, + {'mid': 11, 'fullname': 'Jones.David'}, + {'mid': 12, 'fullname': 'Baker.Anne'}, + {'mid': 13, 'fullname': 'Farrell.Jemima'}, + {'mid': 14, 'fullname': 'Smith.Jack'}, + {'mid': 15, 'fullname': 'Bader.Florence'}, + {'mid': 16, 'fullname': 'Baker.Timothy'}, + {'mid': 24, 'fullname': 'Sarwin.Ramnaresh'}, + {'mid': 27, 'fullname': 'Rumney.Henrietta'}, + {'mid': 28, 'fullname': 'Farrell.David'}, + {'mid': 30, 'fullname': 'Purview.Millicent'}, + {'mid': 35, 'fullname': 'Hunt.John'}, + ] + }] + ) + with self.assertRaisesRegex( + edgedb.InvalidTypeError, + regex="operator '=' cannot be applied to operands of type .*" + ): + await self.con.execute( + r''' + select Facility FILTER .fid = '1'; + ''' + ) + with self.assertRaisesRegex( + edgedb.InvalidTypeError, + regex="operator '=' cannot be applied to operands of type .*" + ): + await self.con.execute( + r''' + select Facility FILTER .bookedby.mid = '1'; + ''' + ) + + async def link_outer_inner_single_link(self): + link_to = "Person" + # query object without link + await self.assert_query_result( + r''' + SELECT + Facility { + fid, discount, name + } FILTER .fid = 1; + ''', + [{ + 'fid': 1, + 'discount': 5 / 25, + 'name': 'Tennis Court 2' + }] + ) + # query object with link + await self.assert_query_result( + r''' + SELECT distinct + Facility { + fid, name, + booked_by: { + mid, + fullname := .surname ++ '.' 
++ .firstname + } ORDER BY .mid, + } FILTER .fid=1; + ''', + [{ + 'fid': 1, + 'name': 'Tennis Court 2', + 'booked_by': {'mid': 3, 'fullname': 'Rownam.Tim'} + }] + ) + with self.assertRaisesRegex( + edgedb.ConstraintViolationError, + regex='deletion of .* is prohibited by link target policy' + ): + await self.con.execute( + f''' + delete {link_to} FILTER .mid=6; + ''' + ) + with self.assertRaisesRegex( + edgedb.InvalidTypeError, + regex="operator '=' cannot be applied to operands of type .*" + ): + await self.con.execute( + r''' + select Facility FILTER .fid = '1'; + ''' + ) + with self.assertRaisesRegex( + edgedb.InvalidTypeError, + regex="operator '=' cannot be applied to operands of type .*" + ): + await self.con.execute( + r''' + select Facility FILTER .booked_by.mid = '1'; + ''' + ) + + async def link_outer_inner_single_link_with_prop(self): + link_to = "Person" + # query object without link + await self.assert_query_result( + r''' + SELECT + Facility { + fid, discount, name + } FILTER .fid = 1; + ''', + [{ + 'fid': 1, + 'discount': 5 / 25, + 'name': 'Tennis Court 2' + }] + ) + # query object with link + await self.assert_query_result( + r''' + SELECT Facility { + fid, name, + bookedby: { + mid, + fullname := .surname ++ '.' ++ .firstname + } ORDER BY .mid, + } FILTER .fid=1; + ''', + [{ + 'fid': 1, + 'name': 'Tennis Court 2', + 'bookedby': {'mid': 0, 'fullname': 'GUEST.GUEST'} + }] + ) + + with self.assertRaisesRegex( + edgedb.ConstraintViolationError, + regex='deletion of .* is prohibited by link target policy' + ): + await self.con.execute( + f''' + delete {link_to} FILTER .mid=0; + ''' + ) + with self.assertRaisesRegex( + edgedb.InvalidTypeError, + regex="operator '=' cannot be applied to operands of type .*" + ): + await self.con.execute( + r''' + select Facility FILTER .fid = '1'; + ''' + ) + with self.assertRaisesRegex( + edgedb.InvalidTypeError, + regex="operator '=' cannot be applied to operands of type .*" + ): + await self.con.execute( + r''' + select Facility FILTER .bookedby.mid = '1'; + ''' + ) + + async def link_outer_inner_multi_link_with_prop(self): + link_to = "Person" + # query object without link + await self.assert_query_result( + r''' + SELECT + Facility { + fid, discount, name + } FILTER .fid = 1; + ''', + [{ + 'fid': 1, + 'discount': 5 / 25, + 'name': 'Tennis Court 2' + }] + ) + # query object with link + await self.assert_query_result( + r''' + SELECT distinct + Facility { + fid, name, + bookedby: { + mid, + fullname := .surname ++ '.' 
++ .firstname + } ORDER BY .mid, + } FILTER .fid=1; + ''', + [{ + 'fid': 1, + 'name': 'Tennis Court 2', + 'bookedby': [ + {'mid': 0, 'fullname': 'GUEST.GUEST'}, + {'mid': 1, 'fullname': 'Smith.Darren'}, + {'mid': 2, 'fullname': 'Smith.Tracy'}, + {'mid': 3, 'fullname': 'Rownam.Tim'}, + {'mid': 4, 'fullname': 'Joplette.Janice'}, + {'mid': 5, 'fullname': 'Butters.Gerald'}, + {'mid': 6, 'fullname': 'Tracy.Burton'}, + {'mid': 7, 'fullname': 'Dare.Nancy'}, + {'mid': 8, 'fullname': 'Boothe.Tim'}, + {'mid': 9, 'fullname': 'Stibbons.Ponder'}, + {'mid': 10, 'fullname': 'Owen.Charles'}, + {'mid': 11, 'fullname': 'Jones.David'}, + {'mid': 12, 'fullname': 'Baker.Anne'}, + {'mid': 13, 'fullname': 'Farrell.Jemima'}, + {'mid': 14, 'fullname': 'Smith.Jack'}, + {'mid': 15, 'fullname': 'Bader.Florence'}, + {'mid': 16, 'fullname': 'Baker.Timothy'}, + {'mid': 24, 'fullname': 'Sarwin.Ramnaresh'}, + {'mid': 27, 'fullname': 'Rumney.Henrietta'}, + {'mid': 28, 'fullname': 'Farrell.David'}, + {'mid': 30, 'fullname': 'Purview.Millicent'}, + {'mid': 35, 'fullname': 'Hunt.John'}, + ] + }] + ) + + with self.assertRaisesRegex( + edgedb.ConstraintViolationError, + regex='deletion of .* is prohibited by link target policy' + ): + await self.con.execute( + f''' + delete {link_to} FILTER .mid=5; + ''' + ) + with self.assertRaisesRegex( + edgedb.InvalidTypeError, + regex="operator '=' cannot be applied to operands of type .*" + ): + await self.con.execute( + r''' + select Facility FILTER .fid = '1'; + ''' + ) + with self.assertRaisesRegex( + edgedb.InvalidTypeError, + regex="operator '=' cannot be applied to operands of type .*" + ): + await self.con.execute( + r''' + select Facility FILTER .bookedby.mid = '1'; + ''' + ) + + async def link_inner_outer(self): + with self.assertRaisesRegex( + edgedb.SchemaDefinitionError, + regex="target_property is required in " + "create link 'member' from object type 'default::NameList' " + "to external object type 'default::Member'." + ): + await self.con.execute( + ''' + create type NameList{ + create property _id -> std::int32 { + create constraint std::exclusive; + }; + create link member -> Member; + create property alive -> std::bool; + }; + ''' + ) + await self.con.execute( + ''' + create type NameList{ + create property _id -> std::int32 { + create constraint std::exclusive; + }; + create link member -> Member{ + on _id to mid + }; + create property alive -> std::bool; + }; + ''' + ) + self.new_outter_type.append("NameList") + with self.assertRaisesRegex( + edgedb.SchemaDefinitionError, + regex="target_property is required in " + "alter link 'member' from object type 'default::NameList' " + "to external object type 'default::Member'." + ): + await self.con.execute( + ''' + alter type NameList{ + alter link member {on id to id}; + }; + ''' + ) + await self.con.execute( + ''' + insert NameList{ + _id := 0, alive := true, + member:= (select Member filter .mid = 0) + }; + insert NameList{ + _id := 1, alive := false, + member:= (select Member filter .mid = 1) + }; + insert NameList{ + _id := 2, alive := true, + member:= (select Member filter .mid = 2) + }; + ''' + ) + # query object with link + await self.assert_query_result( + r''' + SELECT NameList + { + member: { + mid, + fullname := .surname ++ '.' 
++ .firstname + } Limit 1 + } FILTER .alive=true + ORDER BY .member.mid; + ''', + [ + {'member': {'mid': 0, 'fullname': 'GUEST.GUEST'}}, + {'member': {'mid': 2, 'fullname': 'Smith.Tracy'}}, + ] + ) + + with self.assertRaisesRegex( + edgedb.InvalidTypeError, + regex="operator '=' cannot be applied to operands of type .*" + ): + await self.con.execute( + r''' + select NameList FILTER .member.mid = '1'; + ''' + ) + + async def link_inner_outer_on_source_delete(self): + await self.con.execute( + ''' + create type NameList{ + create property _id -> std::int32 { + create constraint std::exclusive; + }; + create link member -> Member { + on _id to mid; + on source delete delete target; + }; + create property alive -> std::bool; + }; + ''' + ) + await self.con.execute( + ''' + insert NameList{ + _id := 0, alive := true, + member:= (select Member filter .mid = 0) + }; + insert NameList{ + _id := 1, alive := false, + member:= (select Member filter .mid = 1) + }; + insert NameList{ + _id := 2, alive := true, + member:= (select Member filter .mid = 2) + }; + ''' + ) + # query object with link + await self.assert_query_result( + r''' + SELECT NameList + { + member: { + mid, + fullname := .surname ++ '.' ++ .firstname + } Limit 1 + } FILTER .alive=true + ORDER BY .member.mid; + ''', + [ + { + 'member': {'mid': 0, 'fullname': 'GUEST.GUEST'} + }, + { + 'member': {'mid': 2, 'fullname': 'Smith.Tracy'} + }, + ] + ) + await self.assert_query_result( + r''' + SELECT count(Member); + ''', + [ + 32 + ] + ) + await self.con.execute( + f''' + delete NameList FILTER ._id=0; + ''' + ) + await self.assert_query_result( + r''' + SELECT NameList + { + member: { + mid, + fullname := .surname ++ '.' ++ .firstname + } Limit 1 + } FILTER .alive=true; + ''', + [ + { + 'member': {'mid': 2, 'fullname': 'Smith.Tracy'} + }, + ] + ) + await self.assert_query_result( + r''' + SELECT count(Member); + ''', + [ + 32 + ] + ) + self.new_outter_type.append("NameList") + + async def link_inner_outer_multi_link_on_source_delete(self): + await self.con.execute( + ''' + create type NameList{ + create property _id -> std::int32 { + create constraint std::exclusive; + }; + create multi link member -> Member { + on _id to mid; + on source delete delete target; + }; + create property alive -> std::bool; + }; + ''' + ) + await self.con.execute( + ''' + insert NameList{ + _id := 0, alive := true, + member:= (select Member filter .mid = 0) + }; + insert NameList{ + _id := 1, alive := false, + member:= (select Member filter .mid = 1) + }; + insert NameList{ + _id := 2, alive := true, + member:= (select Member filter .mid = 2) + }; + ''' + ) + # query object with link + await self.assert_query_result( + r''' + SELECT NameList + { + member: { + mid, + fullname := .surname ++ '.' ++ .firstname + } Limit 1 + } FILTER .alive=true; + ''', + [ + { + 'member': [{'mid': 0, 'fullname': 'GUEST.GUEST'}] + }, + { + 'member': [{'mid': 2, 'fullname': 'Smith.Tracy'}] + }, + ] + ) + await self.assert_query_result( + r''' + SELECT count(Member); + ''', + [ + 32 + ] + ) + await self.con.execute( + f''' + delete NameList FILTER ._id=0; + ''' + ) + await self.assert_query_result( + r''' + SELECT NameList + { + member: { + mid, + fullname := .surname ++ '.' 
++ .firstname + } Limit 1 + } FILTER .alive=true; + ''', + [ + { + 'member': [{'mid': 2, 'fullname': 'Smith.Tracy'}] + }, + ] + ) + await self.assert_query_result( + r''' + SELECT count(Member); + ''', + [ + 32 + ] + ) + + with self.assertRaisesRegex( + edgedb.InvalidTypeError, + regex="operator '=' cannot be applied to operands of type .*" + ): + await self.con.execute( + r''' + select NameList FILTER .member.mid = '1'; + ''' + ) + + self.new_outter_type.append("NameList") + + async def dml_reject(self): + with self.assertRaisesRegex( + edgedb.QueryError, + regex='External .* is read-only.' + ): + await self.con.execute( + ''' + delete Member filter .mid = 0; + ''' + ) + with self.assertRaisesRegex( + edgedb.QueryError, + regex='External .* is read-only.' + ): + await self.con.execute( + ''' + update Member filter .mid = 0 + set {mid := 999}; + ''' + ) + with self.assertRaisesRegex( + edgedb.QueryError, + regex='External .* is read-only.' + ): + await self.con.execute( + ''' + insert Member + { + mid := 999, surname := 'T', firstname := 'E', + address := 'Space', phone_number := '31415926', + joindate := '1970-01-01 00:00:00' + }; + ''' + ) + + +class TestHttpCreateType(http_tb.ExternTestCase, HttpCreateTypeMixin): SCHEMA = os.path.join( os.path.dirname(__file__), 'schemas', 'http_create_type.esdl' @@ -69,9 +784,15 @@ class TestHttpCreateType(tb.ExternTestCase): # EdgeQL/HTTP queries cannot run in a transaction TRANSACTION_ISOLATION = False + new_outter_type = [] + @classmethod def setUpClass(cls): super().setUpClass() + cls.loop.run_until_complete(cls.prepare_external_db()) + + @classmethod + async def prepare_external_db(cls, dbname: str = None): outter_sql_path = os.path.join( os.path.dirname(__file__), 'schemas', 'http_create_type.sql' @@ -81,12 +802,48 @@ def setUpClass(cls): outter_schema = f.read() else: raise OSError(f'Sql file with path : {outter_sql_path} for outter schema not found.') - conn = cls.loop.run_until_complete(cls.pg_conn()) + conn = await cls.pg_conn(dbname) try: - cls.loop.run_until_complete(conn.sql_execute(outter_schema.encode())) + await conn.sql_execute(outter_schema.encode()) finally: conn.terminate() + def setUp(self): + super().setUp() + self.create_relation_with_external() + + def create_relation_with_external(self): + self.new_outter_type = [] + case_name = self._testMethodName + if match := RE_CASE.match(case_name): + from_, to_, link_type, link_prop = match.groups() + if to_ == 'outer': + self.assertTrue(self.create_member_from_outer()) + self.new_outter_type.append('Member') + link_to = "Member" + else: + link_to = "Person" + + if link_prop == 'with_prop': + has_link_prop = True + else: + has_link_prop = False + + if from_ == 'outer': + self.assertTrue(self.create_facility(link_to, link_type, has_link_prop)) + self.new_outter_type.append('Facility') + else: + self.assertTrue(self.create_member_from_outer()) + self.new_outter_type.append('Member') + + def tearDown(self): + try: + while len(self.new_outter_type): + t = self.new_outter_type.pop() + self.loop.run_until_complete(self.con.execute(f'Drop type {t};')) + finally: + super().tearDown() + def create_facility(self, link_to, link_card='single', has_link_prop=True): if has_link_prop or link_card == 'multi': if link_card == 'single': @@ -94,28 +851,28 @@ def create_facility(self, link_to, link_card='single', has_link_prop=True): else: relation = "(select distinct on (facid, memid) * from cd.bookings) as bookings" link_detail = [ - LinkDetail( - name="bookedby", - cardinality=link_card, - 
relation=relation, - type=link_to, - source="facid", - target="memid", - from_="fid", - to="mid", - properties=[ - PropertyDetail( - name="starttime", - type="timestamp", - alias="start_at" - ), - PropertyDetail( - name="slots", - type="int4" - ) - ] if has_link_prop else None, - ), - ] + LinkDetail( + name="bookedby", + cardinality=link_card, + relation=relation, + type=link_to, + source="facid", + target="memid", + from_="fid", + to="mid", + properties=[ + PropertyDetail( + name="starttime", + type="timestamp", + alias="start_at" + ), + PropertyDetail( + name="slots", + type="int4" + ) + ] if has_link_prop else None, + ), + ] else: link_detail = [ LinkDetail( @@ -208,797 +965,99 @@ def create_member_from_outer(self): ) async def test_link_outer_outer_single_link(self): - self.assertTrue(self.create_member_from_outer()) - link_to = "Member" - self.assertTrue(self.create_facility(link_to, has_link_prop=False)) - try: - # query object without link - await self.assert_query_result( - r''' - SELECT - Facility { - fid, discount, name - } FILTER .fid = 1; - ''', - [{ - 'fid': 1, - 'discount': 5 / 25, - 'name': 'Tennis Court 2' - }] - ) - # query object with link - await self.assert_query_result( - r''' - SELECT distinct - Facility { - fid, name, - booked_by: { - mid, - fullname := .surname ++ '.' ++ .firstname - } ORDER BY .mid, - } FILTER .fid=1; - ''', - [{ - 'fid': 1, - 'name': 'Tennis Court 2', - 'booked_by': {'mid': 3, 'fullname': 'Rownam.Tim'} - }] - ) - with self.assertRaisesRegex( - edgedb.InvalidTypeError, - regex="operator '=' cannot be applied to operands of type .*" - ): - await self.con.execute( - r''' - select Facility FILTER .fid = '1'; - ''' - ) - with self.assertRaisesRegex( - edgedb.InvalidTypeError, - regex="operator '=' cannot be applied to operands of type .*" - ): - await self.con.execute( - r''' - select Facility FILTER .booked_by.mid = '1'; - ''' - ) - - finally: - await self.con.execute( - ''' - Drop type Facility; - Drop type Member; - ''' - ) + await self.link_outer_outer_single_link() async def test_link_outer_outer_single_link_with_prop(self): - self.assertTrue(self.create_member_from_outer()) - link_to = "Member" - self.assertTrue(self.create_facility(link_to)) - try: - # query object without link - await self.assert_query_result( - r''' - SELECT - Facility { - fid, discount, name - } FILTER .fid = 1; - ''', - [{ - 'fid': 1, - 'discount': 5 / 25, - 'name': 'Tennis Court 2' - }] - ) - # query object with link - await self.assert_query_result( - r''' - SELECT Facility { - fid, name, - bookedby: { - mid, - fullname := .surname ++ '.' 
++ .firstname - } ORDER BY .mid, - } FILTER .fid=1; - ''', - [{ - 'fid': 1, - 'name': 'Tennis Court 2', - 'bookedby': {'mid': 0, 'fullname': 'GUEST.GUEST'} - }] - ) - with self.assertRaisesRegex( - edgedb.InvalidTypeError, - regex="operator '=' cannot be applied to operands of type .*" - ): - await self.con.execute( - r''' - select Facility FILTER .fid = '1'; - ''' - ) - with self.assertRaisesRegex( - edgedb.InvalidTypeError, - regex="operator '=' cannot be applied to operands of type .*" - ): - await self.con.execute( - r''' - select Facility FILTER .bookedby.mid = '1'; - ''' - ) - finally: - await self.con.execute( - ''' - Drop type Facility; - Drop type Member; - ''' - ) + await self.link_outer_outer_single_link_with_prop() async def test_link_outer_outer_multi_link_with_prop(self): - self.assertTrue(self.create_member_from_outer()) - link_to = "Member" - self.assertTrue(self.create_facility(link_to, 'multi')) - try: - # query object without link - await self.assert_query_result( - r''' - SELECT - Facility { - fid, discount, name - } FILTER .fid = 1; - ''', - [{ - 'fid': 1, - 'discount': 5 / 25, - 'name': 'Tennis Court 2' - }] - ) - # query object with link - await self.assert_query_result( - r''' - SELECT distinct - Facility { - fid, name, - bookedby: { - mid, - fullname := .surname ++ '.' ++ .firstname - } ORDER BY .mid, - } FILTER .fid=1; - ''', - [{ - 'fid': 1, - 'name': 'Tennis Court 2', - 'bookedby': [ - {'mid': 0, 'fullname': 'GUEST.GUEST'}, - {'mid': 1, 'fullname': 'Smith.Darren'}, - {'mid': 2, 'fullname': 'Smith.Tracy'}, - {'mid': 3, 'fullname': 'Rownam.Tim'}, - {'mid': 4, 'fullname': 'Joplette.Janice'}, - {'mid': 5, 'fullname': 'Butters.Gerald'}, - {'mid': 6, 'fullname': 'Tracy.Burton'}, - {'mid': 7, 'fullname': 'Dare.Nancy'}, - {'mid': 8, 'fullname': 'Boothe.Tim'}, - {'mid': 9, 'fullname': 'Stibbons.Ponder'}, - {'mid': 10, 'fullname': 'Owen.Charles'}, - {'mid': 11, 'fullname': 'Jones.David'}, - {'mid': 12, 'fullname': 'Baker.Anne'}, - {'mid': 13, 'fullname': 'Farrell.Jemima'}, - {'mid': 14, 'fullname': 'Smith.Jack'}, - {'mid': 15, 'fullname': 'Bader.Florence'}, - {'mid': 16, 'fullname': 'Baker.Timothy'}, - {'mid': 24, 'fullname': 'Sarwin.Ramnaresh'}, - {'mid': 27, 'fullname': 'Rumney.Henrietta'}, - {'mid': 28, 'fullname': 'Farrell.David'}, - {'mid': 30, 'fullname': 'Purview.Millicent'}, - {'mid': 35, 'fullname': 'Hunt.John'}, - ] - }] - ) - with self.assertRaisesRegex( - edgedb.InvalidTypeError, - regex="operator '=' cannot be applied to operands of type .*" - ): - await self.con.execute( - r''' - select Facility FILTER .fid = '1'; - ''' - ) - with self.assertRaisesRegex( - edgedb.InvalidTypeError, - regex="operator '=' cannot be applied to operands of type .*" - ): - await self.con.execute( - r''' - select Facility FILTER .bookedby.mid = '1'; - ''' - ) - finally: - await self.con.execute( - ''' - Drop type Facility; - Drop type Member; - ''' - ) + await self.link_outer_outer_multi_link_with_prop() async def test_link_outer_inner_single_link(self): - link_to = "Person" - self.assertTrue(self.create_facility(link_to, has_link_prop=False)) - try: - # query object without link - await self.assert_query_result( - r''' - SELECT - Facility { - fid, discount, name - } FILTER .fid = 1; - ''', - [{ - 'fid': 1, - 'discount': 5 / 25, - 'name': 'Tennis Court 2' - }] - ) - # query object with link - await self.assert_query_result( - r''' - SELECT distinct - Facility { - fid, name, - booked_by: { - mid, - fullname := .surname ++ '.' 
++ .firstname - } ORDER BY .mid, - } FILTER .fid=1; - ''', - [{ - 'fid': 1, - 'name': 'Tennis Court 2', - 'booked_by': {'mid': 3, 'fullname': 'Rownam.Tim'} - }] - ) - with self.assertRaisesRegex( - edgedb.ConstraintViolationError, - regex='deletion of .* is prohibited by link target policy' - ): - await self.con.execute( - f''' - delete {link_to} FILTER .mid=6; - ''' - ) - with self.assertRaisesRegex( - edgedb.InvalidTypeError, - regex="operator '=' cannot be applied to operands of type .*" - ): - await self.con.execute( - r''' - select Facility FILTER .fid = '1'; - ''' - ) - with self.assertRaisesRegex( - edgedb.InvalidTypeError, - regex="operator '=' cannot be applied to operands of type .*" - ): - await self.con.execute( - r''' - select Facility FILTER .booked_by.mid = '1'; - ''' - ) - finally: - await self.con.execute( - ''' - Drop type Facility; - ''' - ) + await self.link_outer_inner_single_link() async def test_link_outer_inner_single_link_with_prop(self): - link_to = "Person" - self.assertTrue(self.create_facility(link_to)) - try: - # query object without link - await self.assert_query_result( - r''' - SELECT - Facility { - fid, discount, name - } FILTER .fid = 1; - ''', - [{ - 'fid': 1, - 'discount': 5 / 25, - 'name': 'Tennis Court 2' - }] - ) - # query object with link - await self.assert_query_result( - r''' - SELECT Facility { - fid, name, - bookedby: { - mid, - fullname := .surname ++ '.' ++ .firstname - } ORDER BY .mid, - } FILTER .fid=1; - ''', - [{ - 'fid': 1, - 'name': 'Tennis Court 2', - 'bookedby': {'mid': 0, 'fullname': 'GUEST.GUEST'} - }] - ) - - with self.assertRaisesRegex( - edgedb.ConstraintViolationError, - regex='deletion of .* is prohibited by link target policy' - ): - await self.con.execute( - f''' - delete {link_to} FILTER .mid=0; - ''' - ) - with self.assertRaisesRegex( - edgedb.InvalidTypeError, - regex="operator '=' cannot be applied to operands of type .*" - ): - await self.con.execute( - r''' - select Facility FILTER .fid = '1'; - ''' - ) - with self.assertRaisesRegex( - edgedb.InvalidTypeError, - regex="operator '=' cannot be applied to operands of type .*" - ): - await self.con.execute( - r''' - select Facility FILTER .bookedby.mid = '1'; - ''' - ) - finally: - await self.con.execute( - ''' - Drop type Facility; - ''' - ) + await self.link_outer_inner_single_link_with_prop() async def test_link_outer_inner_multi_link_with_prop(self): - link_to = "Person" - self.assertTrue(self.create_facility(link_to, 'multi')) - try: - # query object without link - await self.assert_query_result( - r''' - SELECT - Facility { - fid, discount, name - } FILTER .fid = 1; - ''', - [{ - 'fid': 1, - 'discount': 5 / 25, - 'name': 'Tennis Court 2' - }] - ) - # query object with link - await self.assert_query_result( - r''' - SELECT distinct - Facility { - fid, name, - bookedby: { - mid, - fullname := .surname ++ '.' 
++ .firstname - } ORDER BY .mid, - } FILTER .fid=1; - ''', - [{ - 'fid': 1, - 'name': 'Tennis Court 2', - 'bookedby': [ - {'mid': 0, 'fullname': 'GUEST.GUEST'}, - {'mid': 1, 'fullname': 'Smith.Darren'}, - {'mid': 2, 'fullname': 'Smith.Tracy'}, - {'mid': 3, 'fullname': 'Rownam.Tim'}, - {'mid': 4, 'fullname': 'Joplette.Janice'}, - {'mid': 5, 'fullname': 'Butters.Gerald'}, - {'mid': 6, 'fullname': 'Tracy.Burton'}, - {'mid': 7, 'fullname': 'Dare.Nancy'}, - {'mid': 8, 'fullname': 'Boothe.Tim'}, - {'mid': 9, 'fullname': 'Stibbons.Ponder'}, - {'mid': 10, 'fullname': 'Owen.Charles'}, - {'mid': 11, 'fullname': 'Jones.David'}, - {'mid': 12, 'fullname': 'Baker.Anne'}, - {'mid': 13, 'fullname': 'Farrell.Jemima'}, - {'mid': 14, 'fullname': 'Smith.Jack'}, - {'mid': 15, 'fullname': 'Bader.Florence'}, - {'mid': 16, 'fullname': 'Baker.Timothy'}, - {'mid': 24, 'fullname': 'Sarwin.Ramnaresh'}, - {'mid': 27, 'fullname': 'Rumney.Henrietta'}, - {'mid': 28, 'fullname': 'Farrell.David'}, - {'mid': 30, 'fullname': 'Purview.Millicent'}, - {'mid': 35, 'fullname': 'Hunt.John'}, - ] - }] - ) - - with self.assertRaisesRegex( - edgedb.ConstraintViolationError, - regex='deletion of .* is prohibited by link target policy' - ): - await self.con.execute( - f''' - delete {link_to} FILTER .mid=5; - ''' - ) - with self.assertRaisesRegex( - edgedb.InvalidTypeError, - regex="operator '=' cannot be applied to operands of type .*" - ): - await self.con.execute( - r''' - select Facility FILTER .fid = '1'; - ''' - ) - with self.assertRaisesRegex( - edgedb.InvalidTypeError, - regex="operator '=' cannot be applied to operands of type .*" - ): - await self.con.execute( - r''' - select Facility FILTER .bookedby.mid = '1'; - ''' - ) - finally: - await self.con.execute( - ''' - Drop type Facility; - ''' - ) + await self.link_outer_inner_multi_link_with_prop() async def test_link_inner_outer(self): - self.assertTrue(self.create_member_from_outer()) - try: - with self.assertRaisesRegex( - edgedb.SchemaDefinitionError, - regex="target_property is required in " - "create link 'member' from object type 'default::NameList' " - "to external object type 'default::Member'." - ): - await self.con.execute( - ''' - create type NameList{ - create property _id -> std::int32 { - create constraint std::exclusive; - }; - create link member -> Member; - create property alive -> std::bool; - }; - ''' - ) - await self.con.execute( - ''' - create type NameList{ - create property _id -> std::int32 { - create constraint std::exclusive; - }; - create link member -> Member{ - on _id to mid - }; - create property alive -> std::bool; - }; - ''' - ) - with self.assertRaisesRegex( - edgedb.SchemaDefinitionError, - regex="target_property is required in " - "alter link 'member' from object type 'default::NameList' " - "to external object type 'default::Member'." - ): - await self.con.execute( - ''' - alter type NameList{ - alter link member {on id to id}; - }; - ''' - ) - await self.con.execute( - ''' - insert NameList{ - _id := 0, alive := true, - member:= (select Member filter .mid = 0) - }; - insert NameList{ - _id := 1, alive := false, - member:= (select Member filter .mid = 1) - }; - insert NameList{ - _id := 2, alive := true, - member:= (select Member filter .mid = 2) - }; - ''' - ) - # query object with link - await self.assert_query_result( - r''' - SELECT NameList - { - member: { - mid, - fullname := .surname ++ '.' 
++ .firstname - } Limit 1 - } FILTER .alive=true - ORDER BY .member.mid; - ''', - [ - {'member': {'mid': 0, 'fullname': 'GUEST.GUEST'}}, - {'member': {'mid': 2, 'fullname': 'Smith.Tracy'}}, - ] - ) - - with self.assertRaisesRegex( - edgedb.InvalidTypeError, - regex="operator '=' cannot be applied to operands of type .*" - ): - await self.con.execute( - r''' - select NameList FILTER .member.mid = '1'; - ''' - ) - finally: - await self.con.execute( - ''' - Drop type NameList; - Drop type Member; - ''' - ) + await self.link_inner_outer() async def test_link_inner_outer_on_source_delete(self): - self.assertTrue(self.create_member_from_outer()) - try: - await self.con.execute( - ''' - create type NameList{ - create property _id -> std::int32 { - create constraint std::exclusive; - }; - create link member -> Member { - on _id to mid; - on source delete delete target; - }; - create property alive -> std::bool; - }; - ''' - ) - await self.con.execute( - ''' - insert NameList{ - _id := 0, alive := true, - member:= (select Member filter .mid = 0) - }; - insert NameList{ - _id := 1, alive := false, - member:= (select Member filter .mid = 1) - }; - insert NameList{ - _id := 2, alive := true, - member:= (select Member filter .mid = 2) - }; - ''' - ) - # query object with link - await self.assert_query_result( - r''' - SELECT NameList - { - member: { - mid, - fullname := .surname ++ '.' ++ .firstname - } Limit 1 - } FILTER .alive=true - ORDER BY .member.mid; - ''', - [ - { - 'member': {'mid': 0, 'fullname': 'GUEST.GUEST'} - }, - { - 'member': {'mid': 2, 'fullname': 'Smith.Tracy'} - }, - ] - ) - await self.assert_query_result( - r''' - SELECT count(Member); - ''', - [ - 32 - ] - ) - await self.con.execute( - f''' - delete NameList FILTER ._id=0; - ''' - ) - await self.assert_query_result( - r''' - SELECT NameList - { - member: { - mid, - fullname := .surname ++ '.' ++ .firstname - } Limit 1 - } FILTER .alive=true; - ''', - [ - { - 'member': {'mid': 2, 'fullname': 'Smith.Tracy'} - }, - ] - ) - await self.assert_query_result( - r''' - SELECT count(Member); - ''', - [ - 32 - ] - ) - finally: - await self.con.execute( - ''' - Drop type NameList; - Drop type Member; - ''' - ) + await self.link_inner_outer_on_source_delete() async def test_link_inner_outer_multi_link_on_source_delete(self): - self.assertTrue(self.create_member_from_outer()) - try: - await self.con.execute( - ''' - create type NameList{ - create property _id -> std::int32 { - create constraint std::exclusive; - }; - create multi link member -> Member { - on _id to mid; - on source delete delete target; - }; - create property alive -> std::bool; - }; - ''' - ) - await self.con.execute( - ''' - insert NameList{ - _id := 0, alive := true, - member:= (select Member filter .mid = 0) - }; - insert NameList{ - _id := 1, alive := false, - member:= (select Member filter .mid = 1) - }; - insert NameList{ - _id := 2, alive := true, - member:= (select Member filter .mid = 2) - }; - ''' - ) - # query object with link - await self.assert_query_result( - r''' - SELECT NameList - { - member: { - mid, - fullname := .surname ++ '.' 
++ .firstname - } Limit 1 - } FILTER .alive=true; - ''', - [ - { - 'member': [{'mid': 0, 'fullname': 'GUEST.GUEST'}] - }, - { - 'member': [{'mid': 2, 'fullname': 'Smith.Tracy'}] - }, - ] - ) - await self.assert_query_result( - r''' - SELECT count(Member); - ''', - [ - 32 - ] - ) - await self.con.execute( - f''' - delete NameList FILTER ._id=0; - ''' - ) - await self.assert_query_result( - r''' - SELECT NameList - { - member: { - mid, - fullname := .surname ++ '.' ++ .firstname - } Limit 1 - } FILTER .alive=true; - ''', - [ - { - 'member': [{'mid': 2, 'fullname': 'Smith.Tracy'}] - }, - ] - ) - await self.assert_query_result( - r''' - SELECT count(Member); - ''', - [ - 32 - ] - ) + await self.link_inner_outer_multi_link_on_source_delete() - with self.assertRaisesRegex( - edgedb.InvalidTypeError, - regex="operator '=' cannot be applied to operands of type .*" - ): - await self.con.execute( - r''' - select NameList FILTER .member.mid = '1'; - ''' - ) - finally: - await self.con.execute( - ''' - Drop type NameList; - Drop type Member; - ''' - ) + async def test_dml_reject(self): + await self.dml_reject() + + +class TestHttpCreateTypeDumpRestore(TestHttpCreateType, server_tb.StableDumpTestCase): + async def prepare(self): + await self.prepare_external_db(dbname=f"{self.get_database_name()}_restored") + + async def test_link_outer_outer_single_link(self): + await self.check_dump_restore( + check_method=HttpCreateTypeMixin.link_outer_outer_single_link, + restore_db_prepare=self.prepare + ) + + async def test_link_outer_outer_single_link_with_prop(self): + await self.check_dump_restore( + check_method=HttpCreateTypeMixin.link_outer_outer_single_link_with_prop, + restore_db_prepare=self.prepare + ) + + async def test_link_outer_outer_multi_link_with_prop(self): + await self.check_dump_restore( + check_method=HttpCreateTypeMixin.link_outer_outer_multi_link_with_prop, + restore_db_prepare=self.prepare + ) + + async def test_link_outer_inner_single_link(self): + await self.check_dump_restore( + check_method=HttpCreateTypeMixin.link_outer_inner_single_link, + restore_db_prepare=self.prepare + ) + + async def test_link_outer_inner_single_link_with_prop(self): + await self.check_dump_restore( + check_method=HttpCreateTypeMixin.link_outer_inner_single_link_with_prop, + restore_db_prepare=self.prepare + ) + + async def test_link_outer_inner_multi_link_with_prop(self): + await self.check_dump_restore( + check_method=HttpCreateTypeMixin.link_outer_inner_multi_link_with_prop, + restore_db_prepare=self.prepare + ) + + async def test_link_inner_outer(self): + await self.check_dump_restore( + check_method=HttpCreateTypeMixin.link_inner_outer, + restore_db_prepare=self.prepare + ) + self.new_outter_type.remove("NameList") + + async def test_link_inner_outer_on_source_delete(self): + await self.check_dump_restore( + check_method=HttpCreateTypeMixin.link_inner_outer_on_source_delete, + restore_db_prepare=self.prepare + ) + self.new_outter_type.remove("NameList") + + async def test_link_inner_outer_multi_link_on_source_delete(self): + await self.check_dump_restore( + check_method=HttpCreateTypeMixin.link_inner_outer_multi_link_on_source_delete, + restore_db_prepare=self.prepare + ) + self.new_outter_type.remove("NameList") async def test_dml_reject(self): - self.assertTrue(self.create_member_from_outer()) - try: - res = await self.con._fetchall_json( - "describe module default;" - ) - with self.assertRaisesRegex( - edgedb.QueryError, - regex='External .* is read-only.' 
-        ):
-            await self.con.execute(
-                '''
-                delete Member filter .mid = 0;
-                '''
-            )
-            with self.assertRaisesRegex(
-                edgedb.QueryError,
-                regex='External .* is read-only.'
-            ):
-                await self.con.execute(
-                    '''
-                    update Member filter .mid = 0
-                    set {mid := 999};
-                    '''
-                )
-            with self.assertRaisesRegex(
-                edgedb.QueryError,
-                regex='External .* is read-only.'
-            ):
-                await self.con.execute(
-                    '''
-                    insert Member
-                    {
-                        mid := 999, surname := 'T', firstname := 'E',
-                        address := 'Space', phone_number := '31415926',
-                        joindate := '1970-01-01 00:00:00'
-                    };
-                    '''
-                )
-        finally:
-            await self.con.execute(
-                '''
-                Drop type Member;
-                '''
-            )
+        await self.check_dump_restore(
+            check_method=HttpCreateTypeMixin.dml_reject,
+            restore_db_prepare=self.prepare
+        )

From 2dfed87125a7cb44fb67a163b269298d94578031 Mon Sep 17 00:00:00 2001
From: JMPRIEST
Date: Wed, 17 May 2023 18:47:43 +0800
Subject: [PATCH 10/20] :construction: implement the logic that makes
 namespace information take effect during compilation
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 edb/graphql/compiler.py              |  10 +-
 edb/graphql/extension.pyx            |   2 +
 edb/pgsql/codegen.py                 |  11 +-
 edb/pgsql/compiler/__init__.py       |   7 +-
 edb/pgsql/dbops/base.py              |  24 ++-
 edb/pgsql/dbops/ddl.py               |  14 +-
 edb/pgsql/dbops/namespace.py         |   5 +-
 edb/pgsql/dbops/roles.py             |   2 +-
 edb/pgsql/delta.py                   | 131 +++++++-----
 edb/pgsql/metaschema.py              |  12 +-
 edb/schema/defines.py                |   2 +
 edb/schema/delta.py                  |   9 +-
 edb/schema/namespace.py              |  21 +-
 edb/server/bootstrap.py              |  11 +
 edb/server/compiler/compiler.py      |  56 ++++--
 edb/server/compiler/dbstate.py       |   9 +-
 edb/server/compiler_pool/pool.py     | 195 +++++++++++-------
 edb/server/compiler_pool/state.py    |   3 +-
 edb/server/compiler_pool/worker.py   |  43 ++--
 edb/server/dbview/dbview.pxd         |   1 +
 edb/server/dbview/dbview.pyx         |  49 +++--
 edb/server/defines.py                |   1 +
 edb/server/pgcon/pgcon.pyx           |   6 +-
 edb/server/protocol/binary.pyx       |   1 +
 edb/server/protocol/binary_v0.pyx    |  41 ++--
 edb/server/protocol/execute.pyx      |  12 ++
 edb/server/protocol/notebook_ext.pyx |   1 +
 edb/server/server.py                 | 287 ++++++++++++++++-----------
 28 files changed, 623 insertions(+), 343 deletions(-)

diff --git a/edb/graphql/compiler.py b/edb/graphql/compiler.py
index 781a9edc82f..a5fbf16e7ce 100644
--- a/edb/graphql/compiler.py
+++ b/edb/graphql/compiler.py
@@ -30,7 +30,7 @@


 GQLCoreCache: Dict[
-    str,
+    Tuple[str, str],
     Dict[
         (s_schema.FlatSchema, uuid.UUID, s_schema.FlatSchema, str),
         graphql.GQLCoreSchema
@@ -40,19 +40,20 @@

 def _get_gqlcore(
     dbname: str,
+    namespace: str,
     std_schema: s_schema.FlatSchema,
     user_schema: s_schema.FlatSchema,
     global_schema: s_schema.FlatSchema,
     module: str = None
 ) -> graphql.GQLCoreSchema:
     key = (std_schema, user_schema.version_id, global_schema, module)
-    if cache := GQLCoreCache.get(dbname):
+    if cache := GQLCoreCache.get((dbname, namespace)):
         if key in cache:
             return cache[key]
         else:
             cache.clear()
     else:
-        cache = GQLCoreCache.setdefault(dbname, {})
+        cache = GQLCoreCache.setdefault((dbname, namespace), {})

     core = graphql.GQLCoreSchema(
         s_schema.ChainedSchema(
@@ -68,6 +69,7 @@ def _get_gqlcore(

 def compile_graphql(
     dbname: str,
+    namespace: str,
     std_schema: s_schema.FlatSchema,
     user_schema: s_schema.FlatSchema,
     global_schema: s_schema.FlatSchema,
@@ -88,7 +90,7 @@ def compile_graphql(
     else:
         ast = graphql.parse_tokens(gql, tokens)

-    gqlcore = _get_gqlcore(dbname, std_schema, user_schema, global_schema, module)
+    gqlcore = 
_get_gqlcore(dbname, namespace, std_schema, user_schema, global_schema, module) return graphql.translate_ast( gqlcore, diff --git a/edb/graphql/extension.pyx b/edb/graphql/extension.pyx index e4d19a15352..00c20b0d78b 100644 --- a/edb/graphql/extension.pyx +++ b/edb/graphql/extension.pyx @@ -229,6 +229,7 @@ async def compile( compiler_pool = server.get_compiler_pool() return await compiler_pool.compile_graphql( db.name, + db.namespace, db.user_schema, server.get_global_schema(), db.reflection_cache, @@ -373,6 +374,7 @@ async def _execute( dbname=db.name, query_cache=False, protocol_version=edbdef.CURRENT_PROTOCOL, + db=db.namespace ) pgcon = await server.acquire_pgcon(db.name) diff --git a/edb/pgsql/codegen.py b/edb/pgsql/codegen.py index 586aeac7ab9..8aa0026fed4 100644 --- a/edb/pgsql/codegen.py +++ b/edb/pgsql/codegen.py @@ -26,6 +26,7 @@ from edb.common.ast import codegen from edb.common import exceptions from edb.common import markup +from edb.schema import defines class SQLSourceGeneratorContext(markup.MarkupExceptionContext): @@ -62,10 +63,11 @@ def __init__(self, msg, *, node=None, details=None, hint=None): class SQLSourceGenerator(codegen.SourceGenerator): - def __init__(self, *args, reordered: bool=False, **kwargs): + def __init__(self, *args, reordered: bool=False, namespace: str = defines.DEFAULT_NS, **kwargs): super().__init__(*args, **kwargs) self.param_index: dict[object, int] = {} self.reordered = reordered + self.namespace = namespace @classmethod def to_source( @@ -118,7 +120,12 @@ def visit_Relation(self, node): if node.schemaname is None: self.write(common.qname(node.name)) else: - self.write(common.qname(node.schemaname, node.name)) + if self.namespace == defines.DEFAULT_NS: + self.write(common.qname(node.schemaname, node.name)) + elif node.schemaname in ['edgedbext', 'edgedb', 'edgedbss', 'edgedbpub', 'edgedbstd', 'edgedbinstdata']: + self.write(common.qname(f"{self.namespace}_{node.schemaname}", node.name)) + else: + self.write(common.qname(node.schemaname, node.name)) def _visit_values_expr(self, node): self.new_lines = 1 diff --git a/edb/pgsql/compiler/__init__.py b/edb/pgsql/compiler/__init__.py index c5d479ed86b..d54eb109351 100644 --- a/edb/pgsql/compiler/__init__.py +++ b/edb/pgsql/compiler/__init__.py @@ -27,6 +27,7 @@ from edb.common import exceptions as edgedb_error from edb.ir import ast as irast +from edb.schema import defines from edb.pgsql import ast as pgast from edb.pgsql import codegen as pgcodegen @@ -147,6 +148,7 @@ def compile_ir_to_sql( expected_cardinality_one: bool=False, pretty: bool=True, backend_runtime_params: Optional[pgparams.BackendRuntimeParams]=None, + namespace: str = defines.DEFAULT_NS ) -> Tuple[str, Dict[str, pgast.Param]]: qtree = compile_ir_to_sql_tree( @@ -172,7 +174,7 @@ def compile_ir_to_sql( argmap = {} # Generate query text - sql_text = run_codegen(qtree, pretty=pretty) + sql_text = run_codegen(qtree, pretty=pretty, namespace=namespace) if ( # pragma: no cover debug.flags.edgeql_compile or debug.flags.edgeql_compile_sql_text @@ -194,8 +196,9 @@ def run_codegen( *, pretty: bool=True, reordered: bool=False, + namespace: str = defines.DEFAULT_NS ) -> str: - codegen = pgcodegen.SQLSourceGenerator(pretty=pretty, reordered=reordered) + codegen = pgcodegen.SQLSourceGenerator(pretty=pretty, reordered=reordered, namespace=namespace) try: codegen.visit(qtree) except pgcodegen.SQLSourceGeneratorError as e: # pragma: no cover diff --git a/edb/pgsql/dbops/base.py b/edb/pgsql/dbops/base.py index aae34aa6d24..50bcd2e4928 100644 --- 
a/edb/pgsql/dbops/base.py +++ b/edb/pgsql/dbops/base.py @@ -72,12 +72,13 @@ class PLExpression(str): class SQLBlock: - def __init__(self): + def __init__(self, namespace_prefix: str = ''): self.commands = [] self._transactional = True + self.namespace_prefix = namespace_prefix def add_block(self): - block = PLTopBlock() + block = PLTopBlock(self.namespace_prefix) self.add_command(block) return block @@ -112,8 +113,11 @@ def is_transactional(self) -> bool: class PLBlock(SQLBlock): - def __init__(self, top_block, level): - super().__init__() + def __init__(self, top_block, level, namespace_prefix: str = ''): + if top_block is None: + super().__init__(namespace_prefix) + else: + super().__init__(top_block.namespace_prefix) self.top_block = top_block self.varcounter = collections.defaultdict(int) self.shared_vars = set() @@ -132,7 +136,7 @@ def get_top_block(self) -> PLTopBlock: return self.top_block def add_block(self): - block = PLBlock(top_block=self.top_block, level=self.level + 1) + block = PLBlock(top_block=self.top_block, level=self.level + 1, namespace_prefix=self.namespace_prefix) self.add_command(block) return block @@ -225,9 +229,9 @@ def declare_var( self, type_name: Union[str, Tuple[str, str]], *, - var_name: str='', - var_name_prefix: str='v', - shared: bool=False, + var_name: str = '', + var_name_prefix: str = 'v', + shared: bool = False, ) -> str: if shared: if not var_name: @@ -244,8 +248,8 @@ def declare_var( class PLTopBlock(PLBlock): - def __init__(self): - super().__init__(top_block=None, level=0) + def __init__(self, namespace_prefix: str = ''): + super().__init__(top_block=None, level=0, namespace_prefix=namespace_prefix) self.declare_var('text', var_name='_dummy_text', shared=True) def add_block(self): diff --git a/edb/pgsql/dbops/ddl.py b/edb/pgsql/dbops/ddl.py index 4618b17a2fd..cd27142d91b 100644 --- a/edb/pgsql/dbops/ddl.py +++ b/edb/pgsql/dbops/ddl.py @@ -115,7 +115,7 @@ def code(self, block: base.PLBlock) -> str: if is_shared: return textwrap.dedent(f'''\ SELECT - edgedb.shobj_metadata( + {block.namespace_prefix}edgedb.shobj_metadata( {objoid}, {classoid}::regclass::text ) @@ -123,7 +123,7 @@ def code(self, block: base.PLBlock) -> str: elif objsubid: return textwrap.dedent(f'''\ SELECT - edgedb.col_metadata( + {block.namespace_prefix}edgedb.col_metadata( {objoid}, {objsubid} ) @@ -131,7 +131,7 @@ def code(self, block: base.PLBlock) -> str: else: return textwrap.dedent(f'''\ SELECT - edgedb.obj_metadata( + {block.namespace_prefix}edgedb.obj_metadata( {objoid}, {classoid}::regclass::text, ) @@ -149,7 +149,7 @@ def code(self, block: base.PLBlock) -> str: SELECT json FROM - edgedbinstdata.instdata + {block.namespace_prefix}edgedbinstdata.instdata WHERE key = {ql(key)} ''') @@ -211,7 +211,7 @@ def code(self, block: base.PLBlock) -> str: metadata = ql(json.dumps(self.metadata)) return textwrap.dedent(f'''\ UPDATE - edgedbinstdata.instdata + {block.namespace_prefix}edgedbinstdata.instdata SET json = {metadata} WHERE @@ -260,7 +260,7 @@ def code(self, block: base.PLBlock) -> str: return textwrap.dedent(f'''\ UPDATE - edgedbinstdata.instdata + {block.namespace_prefix}edgedbinstdata.instdata SET json = {json_v} || {meta_v} WHERE @@ -329,7 +329,7 @@ def code(self, block: base.PLBlock) -> str: json_v, meta_v = self._merge(block) return textwrap.dedent(f'''\ UPDATE - edgedbinstdata.instdata + {block.namespace_prefix}edgedbinstdata.instdata SET json = {json_v} || {meta_v} WHERE diff --git a/edb/pgsql/dbops/namespace.py b/edb/pgsql/dbops/namespace.py index 
2b1d87ff7c3..6cb15ea37e4 100644 --- a/edb/pgsql/dbops/namespace.py +++ b/edb/pgsql/dbops/namespace.py @@ -23,7 +23,8 @@ from . import base from . import ddl -from ..common import quote_ident as qi +from edb.pgsql.common import quote_ident as qi +from edb.schema.defines import DEFAULT_NS class NameSpace(base.DBObject): @@ -39,7 +40,7 @@ def get_type(self): return 'SCHEMA' def get_id(self): - return qi(f"{self.name}_edgedb") + return qi(f"{self.name}_edgedb") if self.name != DEFAULT_NS else qi("edgedb") def is_shared(self) -> bool: return False diff --git a/edb/pgsql/dbops/roles.py b/edb/pgsql/dbops/roles.py index a5505869737..ae7fa6c70d2 100644 --- a/edb/pgsql/dbops/roles.py +++ b/edb/pgsql/dbops/roles.py @@ -153,7 +153,7 @@ def generate_extra(self, block: base.PLBlock) -> None: value = json.dumps(self.object.single_role_metadata) query = base.Query( f''' - UPDATE edgedbinstdata.instdata + UPDATE {block.namespace_prefix}edgedbinstdata.instdata SET json = {ql(value)}::jsonb WHERE key = 'single_role_metadata' ''' diff --git a/edb/pgsql/delta.py b/edb/pgsql/delta.py index 806379fdb9c..4b2f9a47644 100644 --- a/edb/pgsql/delta.py +++ b/edb/pgsql/delta.py @@ -34,7 +34,7 @@ from edb.edgeql import qltypes as ql_ft from edb.edgeql import compiler as qlcompiler -from edb.schema import annos as s_anno +from edb.schema import annos as s_anno, defines from edb.schema import casts as s_casts from edb.schema import scalars as s_scalars from edb.schema import objtypes as s_objtypes @@ -362,7 +362,7 @@ def apply( (SELECT version::text FROM - edgedb."_SchemaSchemaVersion" + {context.pg_schema('edgedb')}."_SchemaSchemaVersion" FOR UPDATE), {ql(str(expected_ver))} )), @@ -370,7 +370,7 @@ def apply( msg => ( 'Cannot serialize DDL: ' || (SELECT version::text FROM - edgedb."_SchemaSchemaVersion") + {context.pg_schema('edgedb')}."_SchemaSchemaVersion") ) ) INTO _dummy_text @@ -467,7 +467,7 @@ def apply( SELECT json FROM - edgedbinstdata.instdata + {context.pg_schema('edgedbinstdata')}.instdata WHERE key = {ql(key)} FOR UPDATE @@ -520,7 +520,7 @@ def apply( msg => ( 'Cannot serialize global DDL: ' || (SELECT version::text FROM - edgedb."_SysGlobalSchemaVersion") + {context.pg_schema('edgedb')}."_SysGlobalSchemaVersion") ) ) INTO _dummy_text @@ -537,7 +537,7 @@ def apply( (SELECT version::text FROM - edgedb."_SysGlobalSchemaVersion" + {context.pg_schema('edgedb')}."_SysGlobalSchemaVersion" ), {ql(str(expected_ver))} )), @@ -545,7 +545,7 @@ def apply( msg => ( 'Cannot serialize global DDL: ' || (SELECT version::text FROM - edgedb."_SysGlobalSchemaVersion") + {context.pg_schema('edgedb')}."_SysGlobalSchemaVersion") ) ) INTO _dummy_text @@ -1171,7 +1171,7 @@ def compile_edgeql_overloaded_function_body( target AS ancestor, index FROM - edgedb."_SchemaObjectType__ancestors" + {context.pg_schema('edgedb')}."_SchemaObjectType__ancestors" WHERE source = {qi(type_param_name)} ) a WHERE ancestor IN ({impl_ids}) @@ -3541,8 +3541,9 @@ def apply( self.apply_scheduled_inhview_updates(schema, context) if is_external: - view_name = ('edgedbpub', str(objtype.id)) - view_name_t = ('edgedbpub', str(objtype.id) + '_t') + schema_name = context.pg_schema('edgedbpub') + view_name = (schema_name, str(objtype.id)) + view_name_t = (schema_name, str(objtype.id) + '_t') self.pgops.add( dbops.DropView( name=view_name, @@ -5160,8 +5161,9 @@ def apply( self.apply_scheduled_inhview_updates(schema, context) if has_extern_table: - view_name = ('edgedbpub', str(link.id)) - view_name_t = ('edgedbpub', str(link.id) + '_t') + schema_name = 
context.pg_schema('edgedbpub')
+            view_name = (schema_name, str(link.id))
+            view_name_t = (schema_name, str(link.id) + '_t')
             self.pgops.add(
                 dbops.DropView(
                     name=view_name,
@@ -5799,24 +5801,31 @@ def get_trigger_proc_name(self, schema, target,
         return common.get_backend_name(
             schema, target, catenate=False, aspect=aspect)

-    def get_trigger_proc_text(self, target, links, *,
-                              disposition, inline, schema):
+    def get_trigger_proc_text(
+        self, target, links, *, disposition, inline, schema, namespace
+    ):
         if inline:
             return self._get_inline_link_trigger_proc_text(
-                target, links, disposition=disposition, schema=schema)
+                target, links, disposition=disposition, schema=schema, namespace=namespace
+            )
         else:
             return self._get_outline_link_trigger_proc_text(
-                target, links, disposition=disposition, schema=schema)
+                target, links, disposition=disposition, schema=schema, namespace=namespace
+            )

-    def _get_dunder_type_trigger_proc_text(self, target, *, schema):
+    def _get_dunder_type_trigger_proc_text(self, target, *, schema, namespace):
+        if namespace == defines.DEFAULT_NS:
+            ns_prefix = ''
+        else:
+            ns_prefix = namespace + '_'
         body = textwrap.dedent('''\
             SELECT
                 CASE WHEN tp.builtin
-                THEN 'edgedbstd'
-                ELSE 'edgedbpub'
+                THEN '{ns_prefix}edgedbstd'
+                ELSE '{ns_prefix}edgedbpub'
                 END AS sname INTO schema_name
-            FROM edgedb."_SchemaType" as tp
+            FROM {ns_prefix}edgedb."_SchemaType" as tp
             WHERE tp.id = OLD.id;

            SELECT EXISTS (
@@ -5840,10 +5849,10 @@ def _get_dunder_type_trigger_proc_text(self, target, *, schema):
                    MESSAGE = 'deletion of {tgtname} (' || OLD.id
                        || ') is prohibited by link target policy',
                    DETAIL = 'Object is still referenced in link __type__'
-                        || ' of ' || edgedb._get_schema_object_name(OLD.id) || ' ('
+                        || ' of ' || {ns_prefix}edgedb._get_schema_object_name(OLD.id) || ' ('
                        || OLD.id || ').';
            END IF;
-            '''.format(tgtname=target.get_displayname(schema)))
+            '''.format(tgtname=target.get_displayname(schema), ns_prefix=ns_prefix))

        text = textwrap.dedent('''\
            DECLARE
@@ -5859,7 +5868,12 @@ def _get_dunder_type_trigger_proc_text(self, target, *, schema):
        return text

    def _get_outline_link_trigger_proc_text(
-            self, target, links, *, disposition, schema):
+        self, target, links, *, disposition, schema, namespace
+    ):
+        if namespace == defines.DEFAULT_NS:
+            ns_prefix = ''
+        else:
+            ns_prefix = namespace + '_'

        chunks = []

@@ -5918,11 +5932,11 @@ def _declare_var(var_prefix, index, var_type):
                    IF FOUND THEN
                        SELECT
-                            edgedb.shortname_from_fullname(link.name),
-                            edgedb._get_schema_object_name(link.{far_endpoint})
+                            {ns_prefix}edgedb.shortname_from_fullname(link.name),
+                            {ns_prefix}edgedb._get_schema_object_name(link.{far_endpoint})
                            INTO linkname, endname
                        FROM
-                            edgedb."_SchemaLink" AS link
+                            {ns_prefix}edgedb."_SchemaLink" AS link
                        WHERE
                            link.id = link_type_id;
                        RAISE foreign_key_violation
@@ -5943,6 +5957,7 @@ def _declare_var(var_prefix, index, var_type):
                    tgtname=target.get_displayname(schema),
                    near_endpoint=near_endpoint,
                    far_endpoint=far_endpoint,
+                    ns_prefix=ns_prefix
                )

                chunks.append(text)
@@ -6185,7 +6200,12 @@ def _resolve_link_group(
        return tuple(group_var)

    def _get_inline_link_trigger_proc_text(
-            self, target, links, *, disposition, schema):
+        self, target, links, *, disposition, schema, namespace
+    ):
+        if namespace == defines.DEFAULT_NS:
+            ns_prefix = ''
+        else:
+            ns_prefix = namespace + '_'

        chunks = []

@@ -6238,11 +6254,11 @@ def _get_inline_link_trigger_proc_text(
                    IF FOUND THEN
                        SELECT
-                            edgedb.shortname_from_fullname(link.name),
-                            edgedb._get_schema_object_name(link.{far_endpoint})
+                            {ns_prefix}edgedb.shortname_from_fullname(link.name),
+                            {ns_prefix}edgedb._get_schema_object_name(link.{far_endpoint})
                            INTO linkname, endname
                        FROM
-                            edgedb."_SchemaLink" AS link
+                            {ns_prefix}edgedb."_SchemaLink" AS link
                        WHERE
                            link.id = link_type_id;
                        RAISE foreign_key_violation
@@ -6528,12 +6544,14 @@ def apply(
        if links or modifications:
            self._update_action_triggers(
-                schema, source, links, disposition='source')
+                schema, source, links, disposition='source', namespace=context.namespace
+            )

        if inline_links or modifications:
            self._update_action_triggers(
                schema, source, inline_links,
-                inline=True, disposition='source')
+                inline=True, disposition='source', namespace=context.namespace
+            )

        # All descendants of affected targets also need to have their
        # triggers updated, so track them down.
@@ -6620,27 +6638,31 @@ def apply(
            key=lambda l: l.get_name(schema))

        if dunder_type_links:
-            self._update_dunder_type_link_triggers(schema, target)
+            self._update_dunder_type_link_triggers(schema, target, context.namespace)

        if links or modifications:
            self._update_action_triggers(
-                schema, target, links, disposition='target')
+                schema, target, links, disposition='target', namespace=context.namespace
+            )

        if inline_links or modifications:
            self._update_action_triggers(
                schema, target, inline_links,
-                disposition='target', inline=True)
+                disposition='target', inline=True, namespace=context.namespace
+            )

        if deferred_links or modifications:
            self._update_action_triggers(
                schema, target, deferred_links,
-                disposition='target', deferred=True)
+                disposition='target', deferred=True, namespace=context.namespace
+            )

        if deferred_inline_links or modifications:
            self._update_action_triggers(
                schema, target, deferred_inline_links,
                disposition='target', deferred=True,
-                inline=True)
+                inline=True, namespace=context.namespace
+            )

        return schema

@@ -6648,6 +6670,7 @@ def _update_dunder_type_link_triggers(
        self,
        schema,
        objtype: s_objtypes.ObjectType,
+        namespace: str
    ) -> None:
        table_name = common.get_backend_name(
            schema, objtype, catenate=False)
@@ -6664,7 +6687,8 @@ def _update_dunder_type_link_triggers(
            is_constraint=True, inherit=True, deferred=False)

        proc_text = self._get_dunder_type_trigger_proc_text(
-            objtype, schema=schema)
+            objtype, schema=schema, namespace=namespace
+        )

        trig_func = dbops.Function(
            name=proc_name, text=proc_text, volatility='volatile',
@@ -6679,13 +6703,15 @@ def _update_dunder_type_link_triggers(
        ))

    def _update_action_triggers(
-            self,
-            schema,
-            objtype: s_objtypes.ObjectType,
-            links: List[s_links.Link], *,
-            disposition: str,
-            deferred: bool=False,
-            inline: bool=False) -> None:
+        self,
+        schema,
+        objtype: s_objtypes.ObjectType,
+        links: List[s_links.Link], *,
+        disposition: str,
+        namespace: str,
+        deferred: bool = False,
+        inline: bool = False,
+    ) -> None:

        table_name = common.get_backend_name(
            schema, objtype, catenate=False)
@@ -6706,7 +6732,8 @@ def _update_action_triggers(
        if links:
            proc_text = self.get_trigger_proc_text(
                objtype, links, disposition=disposition,
-                inline=inline, schema=schema)
+                inline=inline, schema=schema, namespace=namespace
+            )

            trig_func = dbops.Function(
                name=proc_name, text=proc_text, volatility='volatile',
@@ -6763,8 +6790,9 @@ def collect_external_objects(
        view_def = context.external_view[key]

        if context.restoring_external:
-            self.external_views.append(dbops.View(query=view_def, name=('edgedbpub', str(obj.id))))
-            self.external_views.append(dbops.View(query=view_def, name=('edgedbpub', str(obj.id) + '_t')))
+            schema_name = context.pg_schema('edgedbpub')
+            self.external_views.append(dbops.View(query=view_def, name=(schema_name, str(obj.id))))
+            
self.external_views.append(dbops.View(query=view_def, name=(schema_name, str(obj.id) + '_t'))) return columns = [] @@ -6785,7 +6813,7 @@ def collect_external_objects( ptrname = ptr.get_shortname(schema).name if ptrname == 'id': - columns.append("edgedbext.uuid_generate_v1mc() AS id") + columns.append(f"{context.pg_schema('edgedbext')}.uuid_generate_v1mc() AS id") elif ptrname == '__type__': columns.append(f"'{(str(obj.id))}'::uuid AS __type__") elif has_link_table: @@ -6814,8 +6842,9 @@ def collect_external_objects( if join_link_table is not None: query += f", (SELECT * FROM {join_link_table.relation}) AS INNER_T " \ f"where INNER_T.{join_link_table.columns['source']} = SOURCE_T.{source_identity}" - self.external_views.append(dbops.View(query=query, name=('edgedbpub', str(obj.id)))) - self.external_views.append(dbops.View(query=query, name=('edgedbpub', str(obj.id) + '_t'))) + schema_name = context.pg_schema('edgedbpub') + self.external_views.append(dbops.View(query=query, name=(schema_name, str(obj.id)))) + self.external_views.append(dbops.View(query=query, name=(schema_name, str(obj.id) + '_t'))) def apply( diff --git a/edb/pgsql/metaschema.py b/edb/pgsql/metaschema.py index 72197f3fef0..90c869c0d4a 100644 --- a/edb/pgsql/metaschema.py +++ b/edb/pgsql/metaschema.py @@ -4008,8 +4008,18 @@ async def bootstrap( config_spec: edbconfig.Spec ) -> None: commands = dbops.CommandGroup() + default_ns = dbops.NameSpace( + defines.DEFAULT_NS, + metadata=dict( + id=str(uuidgen.uuid1mc()), + builtin=False, + name=defines.DEFAULT_NS, + internal=False + ), + ) commands.add_commands([ dbops.CreateSchema(name='edgedb'), + dbops.SetMetadata(default_ns, default_ns.metadata), dbops.CreateSchema(name='edgedbss'), dbops.CreateSchema(name='edgedbpub'), dbops.CreateSchema(name='edgedbstd'), @@ -4821,7 +4831,7 @@ def _generate_namespace_views(schema: s_schema.Schema) -> List[dbops.View]: AS {qi(ptr_col_name(schema, NameSpace, '__type__'))}, (ns.description->>'name') AS {qi(ptr_col_name(schema, NameSpace, 'name'))}, - (ns.description->>'name__internal') + (ns.description->>'name') AS {qi(ptr_col_name(schema, NameSpace, 'name__internal'))}, ARRAY[]::text[] AS {qi(ptr_col_name(schema, NameSpace, 'computed_fields'))}, diff --git a/edb/schema/defines.py b/edb/schema/defines.py index 81d448d0c2d..0da8afb396e 100644 --- a/edb/schema/defines.py +++ b/edb/schema/defines.py @@ -35,3 +35,5 @@ EDGEDB_SYSTEM_DB = '__edgedbsys__' EDGEDB_SPECIAL_DBS = {EDGEDB_TEMPLATE_DB, EDGEDB_SYSTEM_DB} + +DEFAULT_NS = 'default' diff --git a/edb/schema/delta.py b/edb/schema/delta.py index 1c7f035210f..bc675d23b86 100644 --- a/edb/schema/delta.py +++ b/edb/schema/delta.py @@ -43,7 +43,7 @@ from edb.edgeql import compiler as qlcompiler from edb.edgeql import qltypes -from . import expr as s_expr +from . import expr as s_expr, defines from . import name as sn from . import objects as so from . 
import schema as s_schema @@ -1219,6 +1219,7 @@ def __init__( module_is_implicit: Optional[bool] = False, external_view: Optional[Mapping] = None, restoring_external: Optional[bool] = False, + namespace: str = defines.DEFAULT_NS, ) -> None: self.stack: List[CommandContextToken[Command]] = [] self._cache: Dict[Hashable, Any] = {} @@ -1250,6 +1251,7 @@ def __init__( self.external_view = external_view or immutables.Map() self.restoring_external = restoring_external self.external_objs = set() + self.namespace = namespace @property def modaliases(self) -> Mapping[Optional[str], str]: @@ -1503,6 +1505,11 @@ def compat_ver_is_before( ) -> bool: return self.compat_ver is not None and self.compat_ver < ver + def pg_schema(self, schema_name: str): + if self.namespace == defines.DEFAULT_NS: + return schema_name + return f"{self.namespace}_{schema_name}" + class ContextStack: diff --git a/edb/schema/namespace.py b/edb/schema/namespace.py index 781f0dbf9f4..7b714c5da18 100644 --- a/edb/schema/namespace.py +++ b/edb/schema/namespace.py @@ -19,8 +19,6 @@ from __future__ import annotations -import uuid - from edb import errors from edb.edgeql import ast as qlast from edb.edgeql import qltypes @@ -28,6 +26,7 @@ from . import delta as sd from . import objects as so from . import schema as s_schema +from . import defines class NameSpace( @@ -60,6 +59,13 @@ def _validate_name( f'as such names are reserved for system schemas', context=source_context, ) + if str(name) == defines.DEFAULT_NS: + source_context = self.get_attribute_source_context('name') + raise errors.SchemaDefinitionError( + f'\'{defines.DEFAULT_NS}\' is reserved as name for ' + f'default namespace, use others instead.', + context=source_context, + ) class CreateNameSpace(NameSpaceCommand, sd.CreateExternalObject[NameSpace]): @@ -76,3 +82,14 @@ def validate_create( class DeleteNameSpace(NameSpaceCommand, sd.DeleteExternalObject[NameSpace]): astnode = qlast.DropNameSpace + + def _validate_legal_command( + self, + schema: s_schema.Schema, + context: sd.CommandContext, + ) -> None: + super()._validate_legal_command(schema, context) + if self.classname.name == defines.DEFAULT_NS: + raise errors.ExecutionError( + f"namespace {self.classname.name!r} cannot be dropped" + ) diff --git a/edb/server/bootstrap.py b/edb/server/bootstrap.py index 925bb7c8c9e..c222cdd6834 100644 --- a/edb/server/bootstrap.py +++ b/edb/server/bootstrap.py @@ -1172,6 +1172,17 @@ async def _compile_sys_queries( queries['listdbs'] = sql + _, sql = compile_bootstrap_script( + compiler, + schema, + f"""SELECT ( + SELECT sys::NameSpace + ).name""", + expected_cardinality_one=False, + ) + + queries['listns'] = sql + role_query = ''' SELECT sys::Role { name, diff --git a/edb/server/compiler/compiler.py b/edb/server/compiler/compiler.py index e030c5ecc7a..0d9a708e8c9 100644 --- a/edb/server/compiler/compiler.py +++ b/edb/server/compiler/compiler.py @@ -142,6 +142,8 @@ class CompileContext: restoring_external: Optional[bool] = False # If is test mode from http testmode: Optional[bool] = False + # NameSpace for current compile + namespace: str = defines.DEFAULT_NS DEFAULT_MODULE_ALIASES_MAP = immutables.Map( @@ -387,6 +389,7 @@ def _new_delta_context(self, ctx: CompileContext): context.module = ctx.module context.external_view = ctx.external_view context.restoring_external = ctx.restoring_external + context.namespace = ctx.namespace return context def _process_delta(self, ctx: CompileContext, delta): @@ -418,12 +421,16 @@ def _process_delta(self, ctx: CompileContext, delta): isinstance(c, 
s_ns.NameSpaceCommand) for c in pgdelta.get_subcommands() ) + if ctx.namespace == defines.DEFAULT_NS: + ns_prefix = '' + else: + ns_prefix = ctx.namespace + '_' if db_cmd or ns_cmd: - block = pg_dbops.SQLBlock() + block = pg_dbops.SQLBlock(ns_prefix) new_be_types = new_types = frozenset() else: - block = pg_dbops.PLTopBlock() + block = pg_dbops.PLTopBlock(ns_prefix) def may_has_backend_id(_id): obj = schema.get_by_id(_id, None) @@ -458,7 +465,7 @@ def may_has_backend_id(_id): # Generate schema storage SQL (DML into schema storage tables). if schema_peristence_async: - refl_block = pg_dbops.PLTopBlock() + refl_block = pg_dbops.PLTopBlock(ns_prefix) else: refl_block = None @@ -466,23 +473,23 @@ def may_has_backend_id(_id): ctx, pgdelta, subblock, context=context, schema_persist_block=refl_block ) - + instdata_schemaname = f"{ns_prefix}edgedbinstdata" if schema_peristence_async: if debug.flags.keep_schema_persistence_history: invalid_persist_his = f"""\ - UPDATE "edgedbinstdata"."schema_persist_history" + UPDATE "{instdata_schemaname}"."schema_persist_history" SET active = false WHERE version_id = '{str(ver_id)}'::uuid;\ """ else: invalid_persist_his = f"""\ - DELETE FROM "edgedbinstdata"."schema_persist_history" + DELETE FROM "{instdata_schemaname}"."schema_persist_history" WHERE version_id = '{str(ver_id)}'::uuid;\ """ refl_block.add_command(textwrap.dedent(invalid_persist_his)) main_block_sub = block.add_block() main_block_sub.add_command(textwrap.dedent(f"""\ - INSERT INTO "edgedbinstdata"."schema_persist_history" + INSERT INTO "{instdata_schemaname}"."schema_persist_history" ("version_id", "sql") values ( '{str(ver_id)}'::uuid, {pg_common.quote_bytea_literal(refl_block.to_string().encode())} @@ -490,7 +497,7 @@ def may_has_backend_id(_id): """)) if pgdelta.std_inhview_updates: - stdview_block = pg_dbops.PLTopBlock() + stdview_block = pg_dbops.PLTopBlock(ns_prefix) pgdelta.generate_std_inhview(stdview_block) else: stdview_block = None @@ -534,10 +541,15 @@ def _compile_schema_storage_in_delta( cache = current_tx.get_cached_reflection() + if ctx.namespace == defines.DEFAULT_NS: + schema_name = 'edgedb' + else: + schema_name = f'{ctx.namespace}_edgedb' + with cache.mutate() as cache_mm: for eql, args in meta_blocks: eql_hash = hashlib.sha1(eql.encode()).hexdigest() - fname = ('edgedb', f'__rh_{eql_hash}') + fname = (schema_name, f'__rh_{eql_hash}') if eql_hash in cache_mm: argnames = cache_mm[eql_hash] @@ -609,6 +621,7 @@ def _compile_schema_storage_stmt( expected_cardinality_one=False, bootstrap_mode=ctx.bootstrap_mode, protocol_version=ctx.protocol_version, + namespace=ctx.namespace ) source = edgeql.Source.from_string(eql) @@ -770,6 +783,7 @@ def _compile_ql_query( expected_cardinality_one=ctx.expected_cardinality_one, output_format=_convert_format(ctx.output_format), backend_runtime_params=ctx.backend_runtime_params, + namespace=ctx.namespace ) if ( @@ -1082,6 +1096,11 @@ def _compile_and_apply_ddl_stmt( else: sql = (block.to_string().encode('utf-8'),) + if context.namespace == defines.DEFAULT_NS: + ns_prefix = '' + else: + ns_prefix = context.namespace + '_' + if new_types: # Inject a query returning backend OIDs for the newly # created types. 
@@ -1099,7 +1118,7 @@
                    "backend_id"
                )
            FROM
-                edgedb."_SchemaType"
+                {ns_prefix}edgedb."_SchemaType"
            WHERE
                "id" = any(ARRAY[
                    {', '.join(new_type_ids)}
@@ -1139,6 +1158,7 @@
        drop_db = None
        create_db_template = None
        create_ns = None
+        drop_ns = None
        if isinstance(stmt, qlast.DropDatabase):
            drop_db = stmt.name.name
        elif isinstance(stmt, qlast.CreateDatabase):
            create_db = stmt.name.name
            create_db_template = stmt.template.name if stmt.template else None
        elif isinstance(stmt, qlast.CreateNameSpace):
            create_ns = stmt.name.name
+        elif isinstance(stmt, qlast.DropNameSpace):
+            drop_ns = stmt.name.name

        if debug.flags.delta_execute:
            debug.header('Delta Script')
@@ -1163,6 +1185,7 @@
            create_db=create_db,
            drop_db=drop_db,
            create_ns=create_ns,
+            drop_ns=drop_ns,
            create_db_template=create_db_template,
            has_role_ddl=isinstance(stmt, qlast.RoleCommand),
            ddl_stmt_id=ddl_stmt_id,
@@ -2074,6 +2097,7 @@ def _try_compile(
                unit.drop_db = comp.drop_db
                unit.create_db_template = comp.create_db_template
                unit.create_ns = comp.create_ns
+                unit.drop_ns = comp.drop_ns
                unit.has_role_ddl = comp.has_role_ddl
                unit.ddl_stmt_id = comp.ddl_stmt_id
                if comp.user_schema is not None:
@@ -2420,6 +2444,7 @@ def compile_notebook(

    def compile(
        self,
+        namespace: str,
        user_schema: s_schema.Schema,
        global_schema: s_schema.Schema,
        reflection_cache: Mapping[str, Tuple[str, ...]],
@@ -2440,7 +2465,7 @@ def compile(
        module: Optional[str] = None,
        external_view: Optional[Mapping] = None,
        restoring_external: Optional[bool] = False,
-        testmode: bool = False
+        testmode: bool = False,
    ) -> Tuple[dbstate.QueryUnitGroup,
               Optional[dbstate.CompilerConnectionState]]:
@@ -2487,7 +2512,8 @@ def compile(
            module=module,
            external_view=external_view,
            restoring_external=restoring_external,
-            testmode=testmode
+            testmode=testmode,
+            namespace=namespace
        )

        unit_group = self._compile(ctx=ctx, source=source)
@@ -2505,6 +2531,7 @@ def compile(
    def compile_in_tx(
        self,
        state: dbstate.CompilerConnectionState,
+        namespace: str,
        txid: int,
        source: edgeql.Source,
        output_format: enums.OutputFormat,
@@ -2550,7 +2577,8 @@ def compile_in_tx(
            module=module,
            in_tx=True,
            external_view=external_view,
-            restoring_external=restoring_external
+            restoring_external=restoring_external,
+            namespace=namespace
        )

        return self._compile(ctx=ctx, source=source), ctx.state
@@ -2562,6 +2590,7 @@ def describe_database_dump(
        database_config: immutables.Map[str, config.SettingValue],
        protocol_version: Tuple[int, int],
    ) -> DumpDescriptor:
+        # TODO: namespace support
        schema = s_schema.ChainedSchema(
            self._std_schema,
            user_schema,
@@ -2831,6 +2860,7 @@ def describe_database_restore(
        protocol_version: Tuple[int, int],
        external_view: Dict[str, str]
    ) -> RestoreDescriptor:
+        # TODO: namespace support
        schema_object_ids = {
            (
                s_name.name_from_string(name),
diff --git a/edb/server/compiler/dbstate.py b/edb/server/compiler/dbstate.py
index e3a4417b010..9a13ea91b5a 100644
--- a/edb/server/compiler/dbstate.py
+++ b/edb/server/compiler/dbstate.py
@@ -137,6 +137,7 @@ class DDLQuery(BaseQuery):
    create_db: Optional[str] = None
    create_ns: Optional[str] = None
    drop_db: Optional[str] = None
+    drop_ns: Optional[str] = None
    create_db_template: Optional[str] = None
    has_role_ddl: bool = False
    ddl_stmt_id: Optional[str] = None
@@ -261,10 +262,14 @@ class QueryUnit:
    # close all inactive unused pooled connections to the template db.
    create_db_template: Optional[str] = None

-    # If non-None, contains a name of the DB that is about to be
-    # created/deleted.
+    # If non-None, contains a name of the NameSpace that is about to be
+    # created.
    create_ns: Optional[str] = None

+    # If non-None, contains a name of the NameSpace that is about to be
+    # deleted.
+    drop_ns: Optional[str] = None
+
    # If non-None, the DDL statement will emit data packets marked
    # with the indicated ID.
    ddl_stmt_id: Optional[str] = None
diff --git a/edb/server/compiler_pool/pool.py b/edb/server/compiler_pool/pool.py
index 8262c4788de..fbb0c1dd12f 100644
--- a/edb/server/compiler_pool/pool.py
+++ b/edb/server/compiler_pool/pool.py
@@ -91,11 +91,12 @@ def __repr__(self):

 class MutationHistory:
-    def __init__(self, dbname: str):
+    def __init__(self, dbname: str, namespace: str):
        self._history: List[_SchemaMutation] = []
        self._index: Dict[uuid.UUID, int] = {}
        self._cursor: Dict[uuid.UUID, int] = {}
        self._db = dbname
+        self._namespace = namespace

    @property
    def latest_ver(self):
@@ -109,7 +110,7 @@ def clear(self):
        self._cursor.clear()

    def get_pickled_mutation(self, worker: BaseWorker) -> Optional[bytes]:
-        start = self._cursor.get(worker.get_user_schema_id(self._db))
+        start = self._cursor.get(worker.get_user_schema_id(self._db, self._namespace))
        if start is None:
            return

@@ -117,13 +118,15 @@ def get_pickled_mutation(self, worker: BaseWorker) -> Optional[bytes]:
            mut_bytes = self._history[start].bytes
            if logger.isEnabledFor(logging.DEBUG):
                logger.debug(
-                    f"::CPOOL:: WOKER<{worker.identifier}> | DB<{self._db}> - "
+                    f"::CPOOL:: WORKER<{worker.identifier}> | DB<{self._db}({self._namespace})> - "
                    f"Using stored {self._history[start]} to update."
                )
        else:
            mut = s_schema.SchemaMutationLogger.merge([m.obj for m in self._history[start:]])
            logger.info(
-                f"::CPOOL:: WOKER<{worker.identifier}> | DB<{self._db}> - "
+                f"::CPOOL:: WORKER<{worker.identifier}> | DB<{self._db}({self._namespace})> - "
                f"Using merged {_trim_uuid(mut.target)}> to update."
            )
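+            # The worker is several mutations behind: pickle the merged
+            # chain of pending mutations and ship it as a single update.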
            mut_bytes = pickle.dumps(mut)

@@ -188,11 +189,11 @@ def __init__(
        self._last_used = time.monotonic()
        self._closed = False

-    def get_user_schema_id(self, dbname: str) -> uuid.UUID:
-        if dbname not in self._dbs:
+    def get_user_schema_id(self, dbname: str, namespace: str) -> uuid.UUID:
+        if self._dbs.get(dbname, {}).get(namespace) is None:
            return UNKNOW_VER_ID

-        return self._dbs[dbname].user_schema_version
+        return self._dbs[dbname][namespace].user_schema_version

    @functools.cached_property
    def identifier(self):
@@ -292,7 +293,7 @@ def __init__(
        self._std_schema = std_schema
        self._refl_schema = refl_schema
        self._schema_class_layout = schema_class_layout
-        self._mut_history: Dict[str, MutationHistory] = {}
+        self._mut_history: Dict[str, Dict[str, MutationHistory]] = {}

    @functools.lru_cache(maxsize=None)
    def _get_init_args(self):
@@ -302,19 +303,23 @@ def _get_init_args(self):

    def _get_init_args_uncached(self):
        dbs: state.DatabasesState = immutables.Map()
-        for db in self._dbindex.iter_dbs():
-            db_user_schema = db.user_schema
-            version_id = UNKNOW_VER_ID if db_user_schema is None else db_user_schema.version_id
-            dbs = dbs.set(
-                db.name,
-                state.DatabaseState(
-                    name=db.name,
-                    user_schema=db_user_schema,
-                    user_schema_version=version_id,
-                    reflection_cache=db.reflection_cache,
-                    database_config=db.db_config,
+        for db_name, ns_map in self._dbindex.iter_dbs():
+            namespace = immutables.Map()
+            for ns_name, ns_db in ns_map.items():
+                db_user_schema = ns_db.user_schema
+                version_id = UNKNOW_VER_ID if db_user_schema is None else db_user_schema.version_id
+                namespace = namespace.set(
+                    ns_name,
+                    state.DatabaseState(
+                        name=ns_db.name,
+                        namespace=ns_name,
+                        user_schema=db_user_schema,
+                        user_schema_version=version_id,
+                        reflection_cache=ns_db.reflection_cache,
+                        database_config=ns_db.db_config,
+                    )
                )
-            )
+            dbs = dbs.set(db_name, namespace)

        init_args = (
            dbs,
@@ -337,7 +342,7 @@ async def start(self):

    async def stop(self):
        raise NotImplementedError

-    def collect_worker_schema_ids(self, dbname) -> List[uuid.UUID]:
+    def collect_worker_schema_ids(self, dbname, namespace) -> List[uuid.UUID]:
        raise NotImplementedError

    def get_template_pid(self):
@@ -346,6 +351,7 @@ def get_template_pid(self):
    async def sync_user_schema(
        self,
        dbname,
+        namespace,
        user_schema,
        reflection_cache,
        global_schema,
@@ -359,6 +365,7 @@ async def sync_user_schema(
            preargs, sync_state = await self._compute_compile_preargs(
                worker,
                dbname,
+                namespace,
                user_schema,
                global_schema,
                reflection_cache,
@@ -369,7 +376,7 @@ async def sync_user_schema(
            if preargs[2] is not None:
                logger.debug(f"[W::{worker.identifier}] Sync user schema.")
            else:
-                if worker.get_user_schema_id(dbname) is not UNKNOW_VER_ID:
+                if worker.get_user_schema_id(dbname, namespace) is not UNKNOW_VER_ID:
                    logger.warning(f"[W::{worker.identifier}] Attempt to sync user schema failed.")
                logger.info(f"[W::{worker.identifier}] Initialize user schema.")
@@ -382,6 +389,7 @@ async def _compute_compile_preargs(
        self,
        worker,
        dbname,
+        namespace,
        user_schema,
        global_schema,
        reflection_cache,
@@ -393,27 +401,37 @@ def sync_worker_state_cb(
            *,
            worker,
            dbname,
+            namespace,
            user_schema=None,
            global_schema=None,
            reflection_cache=None,
            database_config=None,
            system_config=None,
        ):
-            worker_db = worker._dbs.get(dbname)
+            worker_db = worker._dbs.get(dbname, {}).get(namespace)
            if worker_db is None:
                assert user_schema is not None
                assert reflection_cache is not None
                assert global_schema is not None
                assert database_config is not None
                assert system_config is not None
+                ns = worker._dbs.get(dbname, immutables.Map())
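+                # immutables.Map.set() returns a new map rather than
+                # mutating in place, so the result must be rebound before
+                # it is stored back on worker._dbs below.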
+                ns = ns.set(
+                    namespace,
+                    state.DatabaseState(
+                        name=dbname,
+                        namespace=namespace,
+                        user_schema=user_schema,
+                        user_schema_version=user_schema.version_id,
+                        reflection_cache=reflection_cache,
+                        database_config=database_config,
+                    )
+                )
-                worker._dbs = worker._dbs.set(dbname, state.DatabaseState(
-                    name=dbname,
-                    user_schema=user_schema,
-                    user_schema_version=user_schema.version_id,
-                    reflection_cache=reflection_cache,
-                    database_config=database_config,
-                ))
+                worker._dbs = worker._dbs.set(dbname, ns)
                worker._global_schema = global_schema
                worker._system_config = system_config
            else:
@@ -423,23 +438,30 @@ def sync_worker_state_cb(
                    or database_config is not None
                ):
                    new_user_schema = user_schema or worker_db.user_schema
-                    worker._dbs = worker._dbs.set(dbname, state.DatabaseState(
-                        name=dbname,
-                        user_schema=new_user_schema,
-                        user_schema_version=new_user_schema.version_id,
-                        reflection_cache=(
-                            reflection_cache or worker_db.reflection_cache),
-                        database_config=(
-                            database_config if database_config is not None
-                            else worker_db.database_config),
-                    ))
+                    ns = worker._dbs[dbname]
+                    ns = ns.set(
+                        namespace,
+                        state.DatabaseState(
+                            name=dbname,
+                            namespace=namespace,
+                            user_schema=new_user_schema,
+                            user_schema_version=new_user_schema.version_id,
+                            reflection_cache=(
+                                reflection_cache or worker_db.reflection_cache),
+                            database_config=(
+                                database_config if database_config is not None
+                                else worker_db.database_config),
+                        )
+                    )
+
+                    worker._dbs = worker._dbs.set(dbname, ns)

                if global_schema is not None:
                    worker._global_schema = global_schema
                if system_config is not None:
                    worker._system_config = system_config

-        worker_db: state.DatabaseState = worker._dbs.get(dbname)
+        worker_db: state.DatabaseState = worker._dbs.get(dbname, {}).get(namespace)
        preargs = (dbname,)
        to_update = {}

@@ -468,16 +490,16 @@ def sync_worker_state_cb(
                    f"Initialize db <{dbname}> schema version to: [{user_schema.version_id}]"
                )
            else:
-                if dbname not in self._mut_history:
+                if self._mut_history.get(dbname, {}).get(namespace) is None:
                    # If this instance has executed no DDL since startup, DDL on
                    # another instance triggers this instance's introspect_db and
                    # invalidates the workers' schema version; in that case
                    # _mut_history may not contain dbname yet.
                    mutation_pickled = None
                else:
-                    mutation_pickled = self._mut_history[dbname].get_pickled_mutation(worker)
+                    mutation_pickled = self._mut_history[dbname][namespace].get_pickled_mutation(worker)
                if mutation_pickled is None:
                    logger.warning(
-                        f"::CPOOL:: WOKER<{worker.identifier}> | DB<{dbname}> - "
+                        f"::CPOOL:: WORKER<{worker.identifier}> | DB<{dbname}({namespace})> - "
                        f"No schema mutation available. "
                        f"Schema <{worker_db.user_schema_version}> is outdated, will issue a full update."
) @@ -525,6 +547,7 @@ def sync_worker_state_cb( sync_worker_state_cb, worker=worker, dbname=dbname, + namespace=namespace, **to_update ) else: @@ -541,6 +564,7 @@ def _release_worker(self, worker, *, put_in_front: bool = True): def append_schema_mutation( self, dbname, + namespace, mut_bytes, mutation: s_schema.SchemaMutationLogger, user_schema, @@ -549,10 +573,12 @@ def append_schema_mutation( database_config, system_config, ): - if is_fresh := (dbname not in self._mut_history): - self._mut_history[dbname] = MutationHistory(dbname) + if is_fresh := (self._mut_history.get(dbname, {}).get(namespace) is None): + ns_map = self._mut_history.get(dbname, {}) + ns_map[namespace] = MutationHistory(dbname, namespace) + self._mut_history[dbname] = ns_map - hist = self._mut_history[dbname] + hist = self._mut_history[dbname][namespace] hist.append(_SchemaMutation( base=mutation.id, target=user_schema.version_id, @@ -561,7 +587,7 @@ def append_schema_mutation( )) if not is_fresh: - usids = self.collect_worker_schema_ids(dbname) + usids = self.collect_worker_schema_ids(dbname, namespace) hist.try_trim_history(usids) if ( @@ -570,18 +596,22 @@ def append_schema_mutation( ): logger.debug(f"Schedule {n} tasks to sync worker's user schema.") for _ in range(n): - asyncio.create_task(self.sync_user_schema( - dbname, - user_schema, - reflection_cache, - global_schema, - database_config, - system_config, - )) + asyncio.create_task( + self.sync_user_schema( + dbname, + namespace, + user_schema, + reflection_cache, + global_schema, + database_config, + system_config, + ) + ) async def compile( self, dbname, + namespace, user_schema, global_schema, reflection_cache, @@ -594,6 +624,7 @@ async def compile( preargs, sync_state = await self._compute_compile_preargs( worker, dbname, + namespace, user_schema, global_schema, reflection_cache, @@ -619,6 +650,7 @@ async def compile( async def compile_in_tx( self, dbname, + namespace, txid, pickled_state, state_id, @@ -658,7 +690,7 @@ async def compile_in_tx( pickled_state = state.REUSE_LAST_STATE_MARKER user_schema = None else: - usid = worker.get_user_schema_id(dbname) + usid = worker.get_user_schema_id(dbname, namespace) if state_id == 0: if base_user_schema.version_id != usid: user_schema = _pickle_memoized(base_user_schema) @@ -682,6 +714,7 @@ async def compile_in_tx( 'compile_in_tx', pickled_state, dbname, + namespace, user_schema, txid, *compile_args @@ -700,6 +733,7 @@ async def compile_notebook( self, dbname, user_schema, + namespace, global_schema, reflection_cache, database_config, @@ -711,6 +745,7 @@ async def compile_notebook( preargs, sync_state = await self._compute_compile_preargs( worker, dbname, + namespace, user_schema, global_schema, reflection_cache, @@ -746,6 +781,7 @@ async def try_compile_rollback( async def compile_graphql( self, dbname, + namespace, user_schema, global_schema, reflection_cache, @@ -758,6 +794,7 @@ async def compile_graphql( preargs, sync_state = await self._compute_compile_preargs( worker, dbname, + namespace, user_schema, global_schema, reflection_cache, @@ -778,6 +815,7 @@ async def compile_graphql( async def infer_expr( self, dbname, + namespace, user_schema, global_schema, reflection_cache, @@ -790,6 +828,7 @@ async def infer_expr( preargs, sync_state = await self._compute_compile_preargs( worker, dbname, + namespace, user_schema, global_schema, reflection_cache, @@ -1033,8 +1072,8 @@ def _release_worker(self, worker, *, put_in_front: bool = True): if worker.get_pid() in self._workers: self._workers_queue.release(worker, 
put_in_front=put_in_front)

-    def collect_worker_schema_ids(self, dbname) -> Iterable[uuid.UUID]:
-        return [w.get_user_schema_id(dbname) for w in self._workers.values()]
+    def collect_worker_schema_ids(self, dbname, namespace) -> Iterable[uuid.UUID]:
+        return [w.get_user_schema_id(dbname, namespace) for w in self._workers.values()]


 @srvargs.CompilerPoolMode.Fixed.assign_implementation
@@ -1126,8 +1165,8 @@ class DebugWorker:
    _last_pickled_state = None
    connected = False

-    def get_user_schema_id(self, dbname):
-        return BaseWorker.get_user_schema_id(self, dbname)  # noqa
+    def get_user_schema_id(self, dbname, namespace):
+        return BaseWorker.get_user_schema_id(self, dbname, namespace)  # noqa

    async def call(self, method_name, *args, sync_state=None):
        from . import worker
@@ -1155,19 +1194,23 @@ def __init__(self, **kwargs):

    @functools.lru_cache(maxsize=None)
    def _get_init_args(self):
        dbs: state.DatabasesState = immutables.Map()
-        for db in self._dbindex.iter_dbs():
-            db_user_schema = db.user_schema
-            version_id = UNKNOW_VER_ID if db_user_schema is None else db_user_schema.version_id
-            dbs = dbs.set(
-                db.name,
-                state.DatabaseState(
-                    name=db.name,
-                    user_schema=db_user_schema,
-                    user_schema_version=version_id,
-                    reflection_cache=db.reflection_cache,
-                    database_config=db.db_config,
+        for db_name, ns_map in self._dbindex.iter_dbs():
+            namespace = immutables.Map()
+            for ns_name, ns_db in ns_map.items():
+                db_user_schema = ns_db.user_schema
+                version_id = UNKNOW_VER_ID if db_user_schema is None else db_user_schema.version_id
+                namespace = namespace.set(
+                    ns_name,
+                    state.DatabaseState(
+                        name=db_name,
+                        namespace=ns_name,
+                        user_schema=db_user_schema,
+                        user_schema_version=version_id,
+                        reflection_cache=ns_db.reflection_cache,
+                        database_config=ns_db.db_config,
+                    )
                )
-            )
+            dbs = dbs.set(db_name, namespace)
        self._worker._dbs = dbs
        self._worker._backend_runtime_params = self._backend_runtime_params
        self._worker._std_schema = self._std_schema
@@ -1228,8 +1271,8 @@ async def stop(self):
        if self._worker.connected:
            self.worker_disconnected(os.getpid())

-    def collect_worker_schema_ids(self, dbname) -> Iterable[uuid.UUID]:
-        return [self._worker.get_user_schema_id(dbname)]
+    def collect_worker_schema_ids(self, dbname, namespace) -> Iterable[uuid.UUID]:
+        return [self._worker.get_user_schema_id(dbname, namespace)]


 @srvargs.CompilerPoolMode.OnDemand.assign_implementation
@@ -1520,7 +1563,7 @@ async def _compute_compile_preargs(self, *args):
            self._sync_lock.release()
        return preargs, callback

-    def collect_worker_schema_ids(self, dbname) -> Iterable[uuid.UUID]:
+    def collect_worker_schema_ids(self, dbname, namespace) -> Iterable[uuid.UUID]:
        return []

diff --git a/edb/server/compiler_pool/state.py b/edb/server/compiler_pool/state.py
index f5444cd6fad..3b7795229bc 100644
--- a/edb/server/compiler_pool/state.py
+++ b/edb/server/compiler_pool/state.py
@@ -31,13 +31,14 @@

 class DatabaseState(typing.NamedTuple):
    name: str
+    namespace: str
    user_schema: typing.Optional[schema.FlatSchema]
    user_schema_version: typing.Optional[uuid.UUID]
    reflection_cache: ReflectionCache
    database_config: immutables.Map[str, config.SettingValue]


-DatabasesState = immutables.Map[str, DatabaseState]
+DatabasesState = immutables.Map[str, immutables.Map[str, DatabaseState]]


 class FailedStateSync(Exception):
diff --git a/edb/server/compiler_pool/worker.py b/edb/server/compiler_pool/worker.py
index add80ad95a7..a373f8f2ee6 100644
--- a/edb/server/compiler_pool/worker.py
+++ b/edb/server/compiler_pool/worker.py
@@ -109,6 +109,9 @@ def __init_worker__(
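+# Worker-side DBS maps dbname -> namespace -> DatabaseState, mirroring
+# the nested DatabasesState layout used by the server-side pool.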
__sync__( dbname: str, + namespace: str, user_schema: Optional[bytes], user_schema_mutation: Optional[bytes], reflection_cache: Optional[bytes], @@ -122,7 +123,7 @@ def __sync__( global INSTANCE_CONFIG try: - db = DBS.get(dbname) + db = DBS.get(dbname, {}).get(namespace) if db is None: assert user_schema is not None assert reflection_cache is not None @@ -130,14 +131,19 @@ def __sync__( user_schema_unpacked = pickle.loads(user_schema) reflection_cache_unpacked = pickle.loads(reflection_cache) database_config_unpacked = pickle.loads(database_config) - db = state.DatabaseState( - dbname, - user_schema_unpacked, - user_schema_unpacked.version_id, - reflection_cache_unpacked, - database_config_unpacked, + ns = DBS.get(dbname, immutables.Map()) + ns.set( + namespace, + state.DatabaseState( + dbname, + namespace, + user_schema_unpacked, + user_schema_unpacked.version_id, + reflection_cache_unpacked, + database_config_unpacked, + ) ) - DBS = DBS.set(dbname, db) + DBS = DBS.set(dbname, ns) else: updates = {} @@ -157,7 +163,9 @@ def __sync__( if updates: db = db._replace(**updates) - DBS = DBS.set(dbname, db) + ns = DBS[dbname] + ns.set(namespace, db) + DBS = DBS.set(dbname, ns) if global_schema is not None: GLOBAL_SCHEMA = pickle.loads(global_schema) @@ -175,6 +183,7 @@ def __sync__( def compile( dbname: str, + namespace: str, user_schema: Optional[bytes], user_schema_mutation: Optional[bytes], reflection_cache: Optional[bytes], @@ -187,6 +196,7 @@ def compile( with util.disable_gc(): db = __sync__( dbname, + namespace, user_schema, user_schema_mutation, reflection_cache, @@ -196,6 +206,7 @@ def compile( ) units, cstate = COMPILER.compile( + namespace, db.user_schema, GLOBAL_SCHEMA, db.reflection_cache, @@ -214,7 +225,7 @@ def compile( return units, pickled_state -def compile_in_tx(cstate, dbname, user_schema_pickled, *args, **kwargs): +def compile_in_tx(cstate, dbname, namespace, user_schema_pickled, *args, **kwargs): global LAST_STATE global DBS @@ -227,17 +238,18 @@ def compile_in_tx(cstate, dbname, user_schema_pickled, *args, **kwargs): if user_schema_pickled is not None: user_schema: s_schema.FlatSchema = pickle.loads(user_schema_pickled) else: - user_schema = DBS.get(dbname).user_schema + user_schema = DBS.get(dbname).get(namespace).user_schema cstate = cstate.restore(user_schema) - units, cstate = COMPILER.compile_in_tx(cstate, *args, **kwargs) + units, cstate = COMPILER.compile_in_tx(cstate, namespace, *args, **kwargs) LAST_STATE = cstate return units, pickle.dumps(cstate.compress(), -1), cstate.base_user_schema_id def compile_notebook( dbname: str, + namespace: str, user_schema: Optional[bytes], user_schema_mutation: Optional[bytes], reflection_cache: Optional[bytes], @@ -249,6 +261,7 @@ def compile_notebook( ): db = __sync__( dbname, + namespace, user_schema, user_schema_mutation, reflection_cache, @@ -270,6 +283,7 @@ def compile_notebook( def infer_expr( dbname: str, + namespace: str, user_schema: Optional[bytes], user_schema_mutation: Optional[bytes], reflection_cache: Optional[bytes], @@ -281,6 +295,7 @@ def infer_expr( ): db = __sync__( dbname, + namespace, user_schema, user_schema_mutation, reflection_cache, @@ -321,6 +336,7 @@ def describe_database_restore( def compile_graphql( dbname: str, + namespace: str, user_schema: Optional[bytes], user_schema_mutation: Optional[bytes], reflection_cache: Optional[bytes], @@ -340,6 +356,7 @@ def compile_graphql( db = __sync__( dbname, + namespace, user_schema, user_schema_mutation, reflection_cache, @@ -350,6 +367,7 @@ def compile_graphql( 
gql_op = graphql.compile_graphql( dbname, + namespace, STD_SCHEMA, db.user_schema, GLOBAL_SCHEMA, @@ -364,6 +382,7 @@ def compile_graphql( ) unit_group, _ = COMPILER.compile( + namespace=namespace, user_schema=db.user_schema, global_schema=GLOBAL_SCHEMA, reflection_cache=db.reflection_cache, diff --git a/edb/server/dbview/dbview.pxd b/edb/server/dbview/dbview.pxd index 0ccb3f2bd04..5823dc5cae0 100644 --- a/edb/server/dbview/dbview.pxd +++ b/edb/server/dbview/dbview.pxd @@ -84,6 +84,7 @@ cdef class Database: bint _log_cache readonly str name + readonly str namespace readonly object dbver readonly object db_config readonly object user_schema diff --git a/edb/server/dbview/dbview.pyx b/edb/server/dbview/dbview.pyx index 46c1ae54463..f75f51efc7b 100644 --- a/edb/server/dbview/dbview.pyx +++ b/edb/server/dbview/dbview.pyx @@ -290,6 +290,7 @@ cdef class Database: self, DatabaseIndex index, str name, + str namespace, *, object user_schema, object db_config, @@ -298,6 +299,7 @@ cdef class Database: object extensions, ): self.name = name + self.namespace = namespace self.dbver = next_dbver() @@ -500,11 +502,11 @@ cdef class Database: if self.user_schema is None: async with self._introspection_lock: if self.user_schema is None: - await self._index._server.introspect_db(self.name) + await self._index._server.introspect_db(self.name, self.namespace) async def persist_schema(self): async with self._introspection_lock: - await self._index._server.persist_user_schema(self.name) + await self._index._server.persist_user_schema(self.name, self.namespace) def schedule_schema_persistence(self): asyncio.create_task(self.persist_schema()) @@ -948,6 +950,10 @@ cdef class DatabaseConnectionView: def __get__(self): return self._db.name + property namespace: + def __get__(self): + return self._db.namespace + property reflection_cache: def __get__(self): return self._db.reflection_cache @@ -1092,6 +1098,7 @@ cdef class DatabaseConnectionView: def save_schema_mutation(self, mut, mut_bytes): self._db._index._server.get_compiler_pool().append_schema_mutation( self.dbname, + self.namespace, mut_bytes, mut, self.get_user_schema(), @@ -1369,6 +1376,7 @@ cdef class DatabaseConnectionView: if self.in_tx(): result = await compiler_pool.compile_in_tx( self.dbname, + self.namespace, self.txid, self._last_comp_state, self._last_comp_state_id, @@ -1390,6 +1398,7 @@ cdef class DatabaseConnectionView: else: result = await compiler_pool.compile( self.dbname, + self.namespace, self.get_user_schema(), self.get_global_schema(), self.reflection_cache, @@ -1455,7 +1464,7 @@ cdef class DatabaseIndex: except KeyError: return 0 - return len((db)._views) + return sum(len(ns._views) for ns in db.values()) def get_sys_config(self): return self._sys_config @@ -1470,15 +1479,15 @@ cdef class DatabaseIndex: def has_db(self, dbname): return dbname in self._dbs - def get_db(self, dbname): + def get_db(self, dbname, namespace): try: - return self._dbs[dbname] + return self._dbs[dbname][namespace] except KeyError: raise errors.UnknownDatabaseError( - f'database {dbname!r} does not exist') + f'database {dbname!r} (namespace: {namespace}) does not exist') - def maybe_get_db(self, dbname): - return self._dbs.get(dbname) + def maybe_get_db(self, dbname, namespace): + return self._dbs.get(dbname, {}).get(namespace) def get_global_schema(self): return self._global_schema @@ -1486,9 +1495,10 @@ cdef class DatabaseIndex: def update_global_schema(self, global_schema): self._global_schema = global_schema - def register_db( + def register_ns( self, dbname, + 
namespace, *, user_schema, db_config, @@ -1497,27 +1507,34 @@ cdef class DatabaseIndex: extensions=None, ): cdef Database db - db = self._dbs.get(dbname) + db = self._dbs.get(dbname, {}).get(namespace) if db is not None: db._set_and_signal_new_user_schema( - user_schema, reflection_cache, backend_ids, db_config) + user_schema, reflection_cache, backend_ids, db_config + ) else: db = Database( self, dbname, + namespace=namespace, user_schema=user_schema, db_config=db_config, reflection_cache=reflection_cache, backend_ids=backend_ids, extensions=extensions, ) - self._dbs[dbname] = db + ns_map = self._dbs.get(dbname, {}) + ns_map[namespace] = db + self._dbs[dbname] = ns_map + + def unregister_ns(self, dbname, namespace): + self._dbs.get(dbname, {}).pop(namespace) def unregister_db(self, dbname): self._dbs.pop(dbname) def iter_dbs(self): - return iter(self._dbs.values()) + return iter(self._dbs.items()) async def _save_system_overrides(self, conn): data = config.to_json( @@ -1584,10 +1601,10 @@ cdef class DatabaseIndex: await self._server._after_system_config_reset( op.setting_name) - def new_view(self, dbname: str, *, query_cache: bool, protocol_version): - db = self.get_db(dbname) + def new_view(self, dbname: str, namespace: str, *, query_cache: bool, protocol_version): + db = self.get_db(dbname, namespace) return (db)._new_view(query_cache, protocol_version) def remove_view(self, view: DatabaseConnectionView): - db = self.get_db(view.dbname) + db = self.get_db(view.dbname, view.namespace) return (db)._remove_view(view) diff --git a/edb/server/defines.py b/edb/server/defines.py index 9877a5fa4cb..96f4cfbbc90 100644 --- a/edb/server/defines.py +++ b/edb/server/defines.py @@ -28,6 +28,7 @@ EDGEDB_SUPERGROUP = 'edgedb_supergroup' EDGEDB_SUPERUSER = s_def.EDGEDB_SUPERUSER EDGEDB_TEMPLATE_DB = s_def.EDGEDB_TEMPLATE_DB +DEFAULT_NS = s_def.DEFAULT_NS EDGEDB_SUPERUSER_DB = 'edgedb' EDGEDB_SYSTEM_DB = s_def.EDGEDB_SYSTEM_DB EDGEDB_ENCODING = 'utf-8' diff --git a/edb/server/pgcon/pgcon.pyx b/edb/server/pgcon/pgcon.pyx index bf1187da24a..96bc6f12410 100644 --- a/edb/server/pgcon/pgcon.pyx +++ b/edb/server/pgcon/pgcon.pyx @@ -1939,14 +1939,16 @@ cdef class PGConnection: event_payload = event_data.get('args') if event == 'schema-changes': dbname = event_payload['dbname'] - self.server._on_remote_ddl(dbname) + namespace = event_payload['namespace'] + self.server._on_remote_ddl(dbname, namespace) elif event == 'database-config-changes': dbname = event_payload['dbname'] self.server._on_remote_database_config_change(dbname) elif event == 'system-config-changes': self.server._on_remote_system_config_change() elif event == 'global-schema-changes': - self.server._on_global_schema_change() + namespace = event_payload['namespace'] + self.server._on_global_schema_change(namespace) else: raise AssertionError(f'unexpected system event: {event!r}') diff --git a/edb/server/protocol/binary.pyx b/edb/server/protocol/binary.pyx index e3f71639628..1a37cc48b77 100644 --- a/edb/server/protocol/binary.pyx +++ b/edb/server/protocol/binary.pyx @@ -117,6 +117,7 @@ DEF QUERY_HEADER_ALLOW_CAPABILITIES = 0xFF04 DEF QUERY_HEADER_EXPLICIT_OBJECTIDS = 0xFF05 DEF QUERY_HEADER_EXPLICIT_MODULE = 0xFF06 DEF QUERY_HEADER_PROHIBIT_MUTATION = 0xFF07 +DEF QUERY_HEADER_EXPLICIT_NS = 0xFF08 DEF SERVER_HEADER_CAPABILITIES = 0x1001 diff --git a/edb/server/protocol/binary_v0.pyx b/edb/server/protocol/binary_v0.pyx index 43f6c0d1d88..f8bd843c5ed 100644 --- a/edb/server/protocol/binary_v0.pyx +++ b/edb/server/protocol/binary_v0.pyx @@ -282,6 +282,7 
@@ cdef class EdgeConnectionBackwardsCompatible(EdgeConnection):
         compiler_pool = server.get_compiler_pool()
 
         dbname = _dbview.dbname
+        namespace = _dbview.namespace
         pgcon = await server.acquire_pgcon(dbname)
         self._in_dump_restore = True
         try:
@@ -312,8 +313,8 @@ cdef class EdgeConnectionBackwardsCompatible(EdgeConnection):
                 ''',
             )
 
-            user_schema = await server.introspect_user_schema(dbname, pgcon)
-            global_schema = await server.introspect_global_schema(pgcon)
+            user_schema = await server.introspect_user_schema(dbname, namespace, pgcon)
+            global_schema = await server.introspect_global_schema(namespace, pgcon)
             db_config = await server.introspect_db_config(pgcon)
             dump_protocol = self.max_protocol
 
@@ -515,6 +516,7 @@ cdef class EdgeConnectionBackwardsCompatible(EdgeConnection):
         self.buffer.finish_message()
 
         dbname = _dbview.dbname
+        namespace = _dbview.namespace
         pgcon = await server.acquire_pgcon(dbname)
 
         self._in_dump_restore = True
@@ -659,7 +661,7 @@ cdef class EdgeConnectionBackwardsCompatible(EdgeConnection):
             self._in_dump_restore = False
             server.release_pgcon(dbname, pgcon)
 
-        await server.introspect_db(dbname)
+        await server.introspect_db(dbname, namespace)
 
         msg = WriteBuffer.new_message(b'C')
         msg.write_int16(0)  # no headers
@@ -1238,12 +1240,13 @@ cdef class EdgeConnectionBackwardsCompatible(EdgeConnection):
                             query_unit.create_db_template, _dbview.dbname
                         )
                     if query_unit.drop_db:
-                        await self.server._on_before_drop_db(
-                            query_unit.drop_db, _dbview.dbname)
-
+                        await self.server._on_before_drop_db(query_unit.drop_db, _dbview.dbname)
+                    if query_unit.create_ns:
+                        await self.server.create_namespace(conn, query_unit.create_ns)
+                    if query_unit.drop_ns:
+                        await self.server._on_before_drop_ns(query_unit.drop_ns, _dbview.namespace)
                     if query_unit.system_config:
-                        await execute.execute_system_config(
-                            conn, _dbview, query_unit)
+                        await execute.execute_system_config(conn, _dbview, query_unit)
                 else:
                     if query_unit.sql:
                         if query_unit.ddl_stmt_id:
@@ -1251,10 +1254,7 @@ cdef class EdgeConnectionBackwardsCompatible(EdgeConnection):
                             if ddl_ret and ddl_ret['new_types']:
                                 new_types = ddl_ret['new_types']
                         elif query_unit.is_transactional:
-                            await conn.sql_execute(
-                                query_unit.sql,
-                                state=state,
-                            )
+                            await conn.sql_execute(query_unit.sql, state=state)
                         else:
                             i = 0
                             for sql in query_unit.sql:
@@ -1270,18 +1270,19 @@ cdef class EdgeConnectionBackwardsCompatible(EdgeConnection):
                         orig_state = None
 
                     if query_unit.create_db:
-                        await self.server.introspect_db(
-                            query_unit.create_db
-                        )
+                        await self.server.introspect_db(query_unit.create_db)
+
+                    if query_unit.create_ns:
+                        await self.server.introspect_db(_dbview.dbname, query_unit.create_ns)
+
+                    if query_unit.drop_db:
+                        self.server._on_after_drop_db(query_unit.drop_db)
 
-                    if query_unit.drop_db:
-                        self.server._on_after_drop_db(
-                            query_unit.drop_db)
+                    if query_unit.drop_ns:
+                        self.server._on_after_drop_ns(_dbview.dbname, query_unit.drop_ns)
 
                     if query_unit.config_ops:
-                        await _dbview.apply_config_ops(
-                            conn,
-                            query_unit.config_ops)
+                        await _dbview.apply_config_ops(conn, query_unit.config_ops)
             except Exception:
                 _dbview.on_error()
 
                 if not conn.in_tx() and _dbview.in_tx():
diff --git a/edb/server/protocol/execute.pyx b/edb/server/protocol/execute.pyx
index 46c72feff2d..e418abe0ddb 100644
--- a/edb/server/protocol/execute.pyx
+++ b/edb/server/protocol/execute.pyx
@@ -84,6 +84,8 @@ async def execute(
                     await server._on_before_drop_db(query_unit.drop_db, dbv.dbname)
                 if query_unit.create_ns:
                     await server.create_namespace(be_conn, query_unit.create_ns)
+                if query_unit.drop_ns:
+                    await
server._on_before_drop_ns(query_unit.drop_ns, dbv.namespace) if query_unit.system_config: await execute_system_config(be_conn, dbv, query_unit) else: @@ -134,9 +136,15 @@ async def execute( if query_unit.create_db: await server.introspect_db(query_unit.create_db) + if query_unit.create_ns: + await server.introspect_db(dbv.dbname, query_unit.create_ns) + if query_unit.drop_db: server._on_after_drop_db(query_unit.drop_db) + if query_unit.drop_ns: + server._on_after_drop_ns(dbv.dbname, query_unit.drop_ns) + if config_ops: await dbv.apply_config_ops(be_conn, config_ops) @@ -380,6 +388,7 @@ def signal_side_effects(dbv, side_effects): server._signal_sysevent( 'schema-changes', dbname=dbv.dbname, + namespace=dbv.namespace, ), interruptable=False, ) @@ -388,6 +397,7 @@ def signal_side_effects(dbv, side_effects): server.create_task( server._signal_sysevent( 'global-schema-changes', + namespace=dbv.namespace, ), interruptable=False, ) @@ -422,6 +432,7 @@ async def parse_execute( dbname=db.name, query_cache=False, protocol_version=edbdef.CURRENT_PROTOCOL, + namespace=db.namespace ) query_req = dbview.QueryRequestInfo( @@ -468,6 +479,7 @@ async def parse_execute_json( dbname=db.name, query_cache=query_cache_enabled, protocol_version=edbdef.CURRENT_PROTOCOL, + namespace=db.namespace ) allow_cap = compiler.Capability(0) if read_only else compiler.Capability.MODIFICATIONS diff --git a/edb/server/protocol/notebook_ext.pyx b/edb/server/protocol/notebook_ext.pyx index 259cb42edc8..1cdccdad95a 100644 --- a/edb/server/protocol/notebook_ext.pyx +++ b/edb/server/protocol/notebook_ext.pyx @@ -176,6 +176,7 @@ async def execute(db, server, queries: list): dbname=db.name, query_cache=False, protocol_version=edbdef.CURRENT_PROTOCOL, + namespace=db.namespace ) bind_data = None diff --git a/edb/server/server.py b/edb/server/server.py index 8f22630e88e..230079110d2 100644 --- a/edb/server/server.py +++ b/edb/server/server.py @@ -78,6 +78,7 @@ ADMIN_PLACEHOLDER = "" logger = logging.getLogger('edb.server') log_metrics = logging.getLogger('edb.server.metrics') +_RE_BYTES_REPL_NS = re.compile(rb'(edgedb)(\.|instdata|pub|ss|std|;)', flags=re.MULTILINE) class RoleDescriptor(TypedDict): @@ -275,11 +276,11 @@ def __init__( self._ns_tpl_sql = None @contextlib.asynccontextmanager - async def aquire_distributed_lock(self, dbname, conn): + async def aquire_distributed_lock(self, dbname, namespace, conn): try: - logger.debug(f'Aquiring advisory lock for <{dbname}>') + logger.debug(f'Aquiring advisory lock for <{dbname}({namespace})>') await conn.sql_execute('select pg_advisory_lock(202304241756)'.encode()) - logger.debug(f'Advisory lock for <{dbname}> aquired') + logger.debug(f'Advisory lock for <{dbname}({namespace})> aquired') yield finally: await conn.sql_execute('select pg_advisory_unlock(202304241756)'.encode()) @@ -416,7 +417,7 @@ async def init(self): await self._load_instance_data() - global_schema = await self.introspect_global_schema() + global_schema = await self.introspect_global_schema(defines.DEFAULT_NS) sys_config = await self.load_sys_config() await self.load_reported_config() @@ -540,19 +541,19 @@ def get_compiler_pool(self): def get_suggested_client_pool_size(self) -> int: return self._suggested_client_pool_size - def get_db(self, *, dbname: str): + def get_db(self, *, dbname: str, namespace: str = defines.DEFAULT_NS): assert self._dbindex is not None - return self._dbindex.get_db(dbname) + return self._dbindex.get_db(dbname, namespace) - def maybe_get_db(self, *, dbname: str): + def maybe_get_db(self, *, dbname: 
str, namespace: str = defines.DEFAULT_NS): assert self._dbindex is not None - return self._dbindex.maybe_get_db(dbname) + return self._dbindex.maybe_get_db(dbname, namespace) - async def new_dbview(self, *, dbname, query_cache, protocol_version): - db = self.get_db(dbname=dbname) + async def new_dbview(self, *, dbname, query_cache, protocol_version, namespace: str = defines.DEFAULT_NS): + db = self.get_db(dbname=dbname, namespace=namespace) await db.introspection() return self._dbindex.new_view( - dbname, query_cache=query_cache, protocol_version=protocol_version + dbname, namespace=namespace, query_cache=query_cache, protocol_version=protocol_version ) def remove_dbview(self, dbview): @@ -639,8 +640,11 @@ async def load_reported_config(self): finally: self._release_sys_pgcon() - async def introspect_global_schema(self, conn=None): - intro_query = self._global_intro_query + async def introspect_global_schema(self, namespace, conn=None): + intro_query = _RE_BYTES_REPL_NS.sub( + namespace.encode('utf-8') + rb'_\1\2', + self._global_intro_query + ) if conn is not None: json_data = await conn.sql_fetch_val(intro_query) else: @@ -657,21 +661,25 @@ async def introspect_global_schema(self, conn=None): schema_class_layout=self._schema_class_layout, ) - async def _reintrospect_global_schema(self): + async def _reintrospect_global_schema(self, namespace): if not self._initing and not self._serving: logger.warning( "global-schema-changes event received during shutdown; " "ignoring." ) return - new_global_schema = await self.introspect_global_schema() + new_global_schema = await self.introspect_global_schema(namespace) self._dbindex.update_global_schema(new_global_schema) self._fetch_roles() - async def introspect_user_schema(self, dbname, conn): - await self._persist_user_schema(dbname, conn) + async def introspect_user_schema(self, dbname, namespace, conn): + await self._persist_user_schema(dbname, namespace, conn) - json_data = await conn.sql_fetch_val(self._local_intro_query) + ns_intro_query = _RE_BYTES_REPL_NS.sub( + namespace.encode('utf-8') + rb'_\1\2', + self._local_intro_query + ) + json_data = await conn.sql_fetch_val(ns_intro_query) base_schema = s_schema.ChainedSchema( self._std_schema, @@ -704,7 +712,7 @@ async def _acquire_intro_pgcon(self, dbname): raise return conn - async def introspect_db(self, dbname): + async def introspect_db(self, dbname, namespace: str = None): """Use this method to (re-)introspect a DB. 
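
# --- illustrative aside (not part of the patch) -----------------------------
# Both introspect_global_schema() and introspect_user_schema() reuse the
# pre-built introspection SQL by rewriting every EdgeDB-owned schema name
# with a "<namespace>_" prefix via _RE_BYTES_REPL_NS.  A self-contained
# demonstration of that substitution (the pattern is copied from the patch;
# the inputs are made up):

import re

_RE_BYTES_REPL_NS = re.compile(
    rb'(edgedb)(\.|instdata|pub|ss|std|;)', flags=re.MULTILINE)

def rewrite_for_namespace(sql: bytes, namespace: str) -> bytes:
    return _RE_BYTES_REPL_NS.sub(namespace.encode('utf-8') + rb'_\1\2', sql)

print(rewrite_for_namespace(b'SELECT id FROM edgedb."_SchemaType"', 'ns1'))
# b'SELECT id FROM ns1_edgedb."_SchemaType"'
print(rewrite_for_namespace(b'edgedbinstdata.schema_persist_history', 'ns1'))
# b'ns1_edgedbinstdata.schema_persist_history'
# -----------------------------------------------------------------------------
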
If the DB is already registered in self._dbindex, its @@ -726,64 +734,81 @@ async def introspect_db(self, dbname): return try: - user_schema = await self.introspect_user_schema(dbname, conn) - - reflection_cache_json = await conn.sql_fetch_val( - b''' - SELECT json_agg(o.c) - FROM ( - SELECT - json_build_object( - 'eql_hash', t.eql_hash, - 'argnames', array_to_json(t.argnames) - ) AS c - FROM - ROWS FROM(edgedb._get_cached_reflection()) - AS t(eql_hash text, argnames text[]) - ) AS o; - ''', - ) - - reflection_cache = immutables.Map({ - r['eql_hash']: tuple(r['argnames']) - for r in json.loads(reflection_cache_json) - }) - - backend_ids_json = await conn.sql_fetch_val( - b''' - SELECT - json_object_agg( - "id"::text, - "backend_id" - )::text - FROM - edgedb."_SchemaType" - ''', - ) - backend_ids = json.loads(backend_ids_json) - - db_config = await self.introspect_db_config(conn) + if namespace is None: + ns_query = self.get_sys_query('listns') + json_data = await conn.sql_fetch_val(ns_query) + ns_list = json.loads(json_data) + else: + ns_list = [namespace] - assert self._dbindex is not None - self._dbindex.register_db( - dbname, - user_schema=user_schema, - db_config=db_config, - reflection_cache=reflection_cache, - backend_ids=backend_ids, - ) + for ns in ns_list: + await self.introspect_ns(conn, dbname, ns) finally: self.release_pgcon(dbname, conn) - async def _persist_user_schema(self, dbname, conn): - async with self.aquire_distributed_lock(dbname, conn): + async def introspect_ns(self, conn, dbname, namespace): + user_schema = await self.introspect_user_schema(dbname, namespace, conn) + if namespace == defines.DEFAULT_NS: + schema_name = 'edgedb' + else: + schema_name = f"{namespace}_edgedb" + reflection_cache_json = await conn.sql_fetch_val( + f''' + SELECT json_agg(o.c) + FROM ( + SELECT + json_build_object( + 'eql_hash', t.eql_hash, + 'argnames', array_to_json(t.argnames) + ) AS c + FROM + ROWS FROM({schema_name}._get_cached_reflection()) + AS t(eql_hash text, argnames text[]) + ) AS o; + '''.encode('utf-8'), + ) + reflection_cache = immutables.Map( + { + r['eql_hash']: tuple(r['argnames']) + for r in json.loads(reflection_cache_json) + } + ) + backend_ids_json = await conn.sql_fetch_val( + f''' + SELECT + json_object_agg( + "id"::text, + "backend_id" + )::text + FROM + {schema_name}."_SchemaType" + '''.encode('utf-8'), + ) + backend_ids = json.loads(backend_ids_json) + db_config = await self.introspect_db_config(conn) + assert self._dbindex is not None + self._dbindex.register_ns( + dbname, + namespace=namespace, + user_schema=user_schema, + db_config=db_config, + reflection_cache=reflection_cache, + backend_ids=backend_ids, + ) + + async def _persist_user_schema(self, dbname, namespace, conn): + if namespace == defines.DEFAULT_NS: + schema_name = 'edgedbinstdata' + else: + schema_name = f"{namespace}_edgedbinstdata" + async with self.aquire_distributed_lock(dbname, namespace, conn): persist_sqls = await conn.sql_fetch( - b'''\ + f'''\ SELECT "version_id", convert_from("sql", 'utf8') from - edgedbinstdata.schema_persist_history + {schema_name}.schema_persist_history WHERE active ORDER BY "timestamp" - ''' + '''.encode('utf-8') ) if not persist_sqls: logger.debug(f"No schema persistence to do.") @@ -791,15 +816,15 @@ async def _persist_user_schema(self, dbname, conn): for vid, sql in persist_sqls: await conn.sql_execute(sql) - logger.debug(f"Finish schema persistence for <{dbname}: {uuid.UUID(bytes=vid)}>") + logger.debug(f"Finish schema persistence for <{dbname}({namespace}): 
{uuid.UUID(bytes=vid)}>") - async def persist_user_schema(self, dbname): + async def persist_user_schema(self, dbname, namespace): conn = await self._acquire_intro_pgcon(dbname) if not conn: return try: - await self._persist_user_schema(dbname, conn) + await self._persist_user_schema(dbname, namespace, conn) finally: self.release_pgcon(dbname, conn) @@ -823,7 +848,7 @@ async def introspect_db_config(self, conn): async def _early_introspect_db(self, dbname): """We need to always introspect the extensions for each database. - Otherwise we won't know to accept connections for graphql or + Otherwise, we won't know to accept connections for graphql or http, for example, until a native connection is made. """ logger.info("introspecting extensions for database '%s'", dbname) @@ -833,25 +858,34 @@ async def _early_introspect_db(self, dbname): return try: - extension_names_json = await conn.sql_fetch_val( - b''' - SELECT json_agg(name) FROM edgedb."_SchemaExtension"; - ''', - ) - if extension_names_json: - extensions = set(json.loads(extension_names_json)) - else: - extensions = set() + ns_query = self.get_sys_query('listns') + json_data = await conn.sql_fetch_val(ns_query) + ns_list = json.loads(json_data) + for ns in ns_list: + if ns == defines.DEFAULT_NS: + schema_name = 'edgedb' + else: + schema_name = f"{ns}_edgedb" + extension_names_json = await conn.sql_fetch_val( + f''' + SELECT json_agg(name) FROM {schema_name}."_SchemaExtension"; + '''.encode('utf-8') + ) + if extension_names_json: + extensions = set(json.loads(extension_names_json)) + else: + extensions = set() - assert self._dbindex is not None - self._dbindex.register_db( - dbname, - user_schema=None, - db_config=None, - reflection_cache=None, - backend_ids=None, - extensions=extensions, - ) + assert self._dbindex is not None + self._dbindex.register_ns( + dbname, + namespace=ns, + user_schema=None, + db_config=None, + reflection_cache=None, + backend_ids=None, + extensions=extensions, + ) finally: self.release_pgcon(dbname, conn) @@ -1066,6 +1100,16 @@ async def _on_before_drop_db( await self._ensure_database_not_connected(dbname) + async def _on_before_drop_ns( + self, + namespace: str, + current_namespace: str + ) -> None: + if current_namespace == namespace: + raise errors.ExecutionError( + f'cannot drop the currently open current_namespace {namespace!r}' + ) + async def _on_before_create_db_from_template( self, dbname: str, @@ -1099,6 +1143,10 @@ def _on_after_drop_db(self, dbname: str): metrics.background_errors.inc(1.0, 'on_after_drop_db') raise + def _on_after_drop_ns(self, dbname: str, namespace: str): + assert self._dbindex is not None + self._dbindex.unregister_ns(dbname, namespace) + async def _on_system_config_add(self, setting_name, value): # CONFIGURE INSTANCE INSERT ConfigObject; pass @@ -1248,7 +1296,7 @@ async def _signal_sysevent(self, event, **kwargs): metrics.background_errors.inc(1.0, 'signal_sysevent') raise - def _on_remote_ddl(self, dbname): + def _on_remote_ddl(self, dbname, namespace): if not self._accept_new_tasks: return @@ -1256,7 +1304,7 @@ def _on_remote_ddl(self, dbname): # on the __edgedb_sysevent__ channel async def task(): try: - await self.introspect_db(dbname) + await self.introspect_db(dbname, namespace) except Exception: metrics.background_errors.inc(1.0, 'on_remote_ddl') raise @@ -1313,19 +1361,19 @@ async def task(): self.create_task(task(), interruptable=True) - def _on_global_schema_change(self): + def _on_global_schema_change(self, namespace): if not self._accept_new_tasks: return - async 
+        async def task(ns):
             try:
-                await self._reintrospect_global_schema()
+                await self._reintrospect_global_schema(ns)
             except Exception:
                 metrics.background_errors.inc(
                     1.0, 'on_global_schema_change')
                 raise
 
-        self.create_task(task(), interruptable=True)
+        self.create_task(task(namespace), interruptable=True)
 
     def _on_sys_pgcon_connection_lost(self, exc):
         try:
@@ -1903,11 +1951,9 @@ def on_switch_over(self):
             )
 
     async def create_namespace(self, be_conn: pgcon.PGConnection, name: str):
-        tpl_sql = re.sub(
-            rb'(edgedb)(\.|instdata|pub|ss|std|;)',
+        tpl_sql = _RE_BYTES_REPL_NS.sub(
             name.encode('utf-8') + rb'_\1\2',
             self._ns_tpl_sql,
-            flags=re.MULTILINE,
         )
         tpl_sql = re.sub(
             rb'({ns_edgedbext})',
@@ -1953,26 +1999,31 @@ def serialize_config(cfg):
             )
 
         dbs = {}
-        for db in self._dbindex.iter_dbs():
-            if db.name in defines.EDGEDB_SPECIAL_DBS:
+        for db_name, ns_map in self._dbindex.iter_dbs():
+            if db_name in defines.EDGEDB_SPECIAL_DBS:
                 continue
 
-            dbs[db.name] = dict(
-                name=db.name,
-                dbver=db.dbver,
-                config=serialize_config(db.db_config),
-                extensions=sorted(db.extensions),
-                query_cache_size=db.get_query_cache_size(),
-                connections=[
-                    dict(
-                        in_tx=view.in_tx(),
-                        in_tx_error=view.in_tx_error(),
-                        config=serialize_config(view.get_session_config()),
-                        module_aliases=view.get_modaliases(),
-                    )
-                    for view in db.iter_views()
-                ],
-            )
+            ns = {}
+            for ns_name, ns_db in ns_map.items():
+                ns[ns_name] = dict(
+                    name=ns_db.name,
+                    namespace=ns_name,
+                    dbver=ns_db.dbver,
+                    config=serialize_config(ns_db.db_config),
+                    extensions=sorted(ns_db.extensions),
+                    query_cache_size=ns_db.get_query_cache_size(),
+                    connections=[
+                        dict(
+                            in_tx=view.in_tx(),
+                            in_tx_error=view.in_tx_error(),
+                            config=serialize_config(view.get_session_config()),
+                            module_aliases=view.get_modaliases(),
+                        )
+                        for view in ns_db.iter_views()
+                    ],
+                )
+
+            dbs[db_name] = dict(name=db_name, namespace=ns)
 
         obj['databases'] = dbs
 

From 7c991f180a25bac11999ad9512707e91b1a2d42f Mon Sep 17 00:00:00 2001
From: JMPRIEST
Date: Thu, 18 May 2023 17:14:22 +0800
Subject: [PATCH 11/20] :bug: remove the namespace info passed into the SQL
 block when creating a namespace or a database
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 edb/server/compiler/compiler.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/edb/server/compiler/compiler.py b/edb/server/compiler/compiler.py
index 389b6383951..414fb615185 100644
--- a/edb/server/compiler/compiler.py
+++ b/edb/server/compiler/compiler.py
@@ -428,7 +428,7 @@ def _process_delta(self, ctx: CompileContext, delta):
             ns_prefix = ctx.namespace + '_'
 
         if db_cmd or ns_cmd:
-            block = pg_dbops.SQLBlock(ns_prefix)
+            block = pg_dbops.SQLBlock()
             new_be_types = new_types = frozenset()
         else:
             block = pg_dbops.PLTopBlock(ns_prefix)

From 55e26ad0273f5b411e8dd8cc19506413414fdefc Mon Sep 17 00:00:00 2001
From: JMPRIEST
Date: Mon, 22 May 2023 12:02:31 +0800
Subject: [PATCH 12/20] :construction: implement the logic that makes the
 namespace info take effect during compilation
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 edb/graphql/extension.pyx            |  37 +-
 edb/pgsql/codegen.py                 |   2 +-
 edb/pgsql/dbops/namespace.py         |   4 +-
 edb/pgsql/delta.py                   |   9 +-
 edb/schema/defines.py                |   1 +
 edb/schema/reflection/structure.py   |   4 +-
 edb/server/compiler/compiler.py      |   2 +
edb/server/compiler/dbstate.py | 4 +- edb/server/compiler_pool/pool.py | 20 +- edb/server/compiler_pool/worker.py | 26 +- edb/server/dbview/dbview.pxd | 51 +-- edb/server/dbview/dbview.pyx | 484 ++++++++++++++++----------- edb/server/protocol/binary.pxd | 2 +- edb/server/protocol/binary.pyx | 17 +- edb/server/protocol/binary_v0.pyx | 43 ++- edb/server/protocol/edgeql_ext.pyx | 14 +- edb/server/protocol/execute.pyx | 34 +- edb/server/protocol/extern_obj.py | 3 + edb/server/protocol/infer_expr.py | 11 +- edb/server/protocol/notebook_ext.pyx | 3 +- edb/server/protocol/schema_info.py | 18 +- edb/server/protocol/system_api.py | 3 +- edb/server/server.py | 86 ++--- 23 files changed, 523 insertions(+), 355 deletions(-) diff --git a/edb/graphql/extension.pyx b/edb/graphql/extension.pyx index 00c20b0d78b..017fe4c5b30 100644 --- a/edb/graphql/extension.pyx +++ b/edb/graphql/extension.pyx @@ -37,7 +37,7 @@ from edb import _graphql_rewrite from edb import errors from edb.graphql import errors as gql_errors from edb.server.dbview cimport dbview -from edb.server import compiler +from edb.server import compiler, defines from edb.server import defines as edbdef from edb.server.pgcon import errors as pgerrors from edb.server.protocol import execute @@ -97,6 +97,7 @@ async def handle_request( globals = None query = None module = None + namespace = defines.DEFAULT_NS limit = 0 try: @@ -111,6 +112,7 @@ async def handle_request( variables = body.get('variables') module = body.get('module') limit = body.get('limit', 0) + namespace = body.get('namespace', defines.DEFAULT_NS) globals = body.get('globals') elif request.content_type == 'application/graphql': query = request.body.decode('utf-8') @@ -157,6 +159,12 @@ async def handle_request( else: limit = 0 + namespace = qs.get('namespace') + if namespace is not None: + namespace = namespace[0] + else: + namespace = defines.DEFAULT_NS + else: raise TypeError('expected a GET or a POST request') @@ -186,7 +194,7 @@ async def handle_request( response.content_type = b'application/json' try: result = await _execute( - db, server, query, + db, namespace, server, query, operation_name, variables, globals, query_only, module or None, limit ) @@ -216,6 +224,7 @@ async def handle_request( async def compile( db, + ns, server, query: str, tokens: Optional[List[Tuple[int, int, int, str]]], @@ -229,10 +238,10 @@ async def compile( compiler_pool = server.get_compiler_pool() return await compiler_pool.compile_graphql( db.name, - db.namespace, - db.user_schema, + ns.name, + ns.user_schema, server.get_global_schema(), - db.reflection_cache, + ns.reflection_cache, db.db_config, server.get_compilation_system_config(), query, @@ -247,9 +256,16 @@ async def compile( async def _execute( - db, server, query, operation_name, variables, + db, namespace, server, query, operation_name, variables, globals, query_only, module, limit ): + + if namespace not in db.ns_map: + raise errors.InternalServerError( + f'NameSpace: [{namespace}] not in current db(ver:{db.dbver})' + ) + ns = db.ns_map[namespace] + dbver = db.dbver query_cache = server._http_query_cache @@ -299,7 +315,7 @@ async def _execute( print(f'key_vars: {key_var_names}') print(f'variables: {vars}') - cache_key = ('graphql', prepared_query, key_vars, operation_name, dbver, query_only, module, limit) + cache_key = ('graphql', prepared_query, key_vars, operation_name, dbver, query_only, namespace, module, limit) use_prep_stmt = False entry: CacheEntry = None @@ -308,13 +324,14 @@ async def _execute( if isinstance(entry, CacheRedirect): 
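
# --- illustrative aside (not part of the patch) -----------------------------
# Why `namespace` has joined the HTTP/GraphQL cache key above: two namespaces
# in the same database can define different schemas for identical query text,
# so omitting it would let one namespace serve plans compiled for another.
# The key shape mirrors the tuple built above; make_cache_key is illustrative.

def make_cache_key(prepared_query, key_vars, operation_name,
                   dbver, query_only, namespace, module, limit):
    return ('graphql', prepared_query, key_vars, operation_name,
            dbver, query_only, namespace, module, limit)

k1 = make_cache_key('query { x }', (), None, 1, False, 'default', None, 0)
k2 = make_cache_key('query { x }', (), None, 1, False, 'tenant_a', None, 0)
assert k1 != k2  # same query text, different namespace -> different cache slot
# -----------------------------------------------------------------------------
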
key_vars2 = tuple(vars[k] for k in entry.key_vars) - cache_key2 = (prepared_query, key_vars2, operation_name, dbver, query_only, module, limit) + cache_key2 = (prepared_query, key_vars2, operation_name, dbver, query_only, namespace, module, limit) entry = query_cache.get(cache_key2, None) if entry is None: if rewritten is not None: qug, gql_op = await compile( db, + ns, server, query, rewritten.tokens(gql_lexer.TokenKind), @@ -328,6 +345,7 @@ async def _execute( else: qug, gql_op = await compile( db, + ns, server, query, None, @@ -373,8 +391,7 @@ async def _execute( dbv = await server.new_dbview( dbname=db.name, query_cache=False, - protocol_version=edbdef.CURRENT_PROTOCOL, - db=db.namespace + protocol_version=edbdef.CURRENT_PROTOCOL ) pgcon = await server.acquire_pgcon(db.name) diff --git a/edb/pgsql/codegen.py b/edb/pgsql/codegen.py index 8aa0026fed4..abfc80bb8d5 100644 --- a/edb/pgsql/codegen.py +++ b/edb/pgsql/codegen.py @@ -122,7 +122,7 @@ def visit_Relation(self, node): else: if self.namespace == defines.DEFAULT_NS: self.write(common.qname(node.schemaname, node.name)) - elif node.schemaname in ['edgedbext', 'edgedb', 'edgedbss', 'edgedbpub', 'edgedbstd', 'edgedbinstdata']: + elif node.schemaname in defines.EDGEDB_OWNED_DBS: self.write(common.qname(f"{self.namespace}_{node.schemaname}", node.name)) else: self.write(common.qname(node.schemaname, node.name)) diff --git a/edb/pgsql/dbops/namespace.py b/edb/pgsql/dbops/namespace.py index 6cb15ea37e4..678528886d6 100644 --- a/edb/pgsql/dbops/namespace.py +++ b/edb/pgsql/dbops/namespace.py @@ -24,7 +24,7 @@ from . import base from . import ddl from edb.pgsql.common import quote_ident as qi -from edb.schema.defines import DEFAULT_NS +from edb.schema.defines import DEFAULT_NS, EDGEDB_OWNED_DBS class NameSpace(base.DBObject): @@ -63,7 +63,7 @@ def code(self, block: base.PLBlock) -> str: schemas = ",".join( [ qi(f"{self.name}_{schema}") - for schema in ['edgedbext', 'edgedb', 'edgedbss', 'edgedbpub', 'edgedbstd', 'edgedbinstdata', ] + for schema in EDGEDB_OWNED_DBS ] ) return f'DROP SCHEMA {schemas} CASCADE;' diff --git a/edb/pgsql/delta.py b/edb/pgsql/delta.py index 6a414d5e012..296053b8273 100644 --- a/edb/pgsql/delta.py +++ b/edb/pgsql/delta.py @@ -5849,7 +5849,7 @@ def _get_dunder_type_trigger_proc_text(self, target, *, schema, namespace): MESSAGE = 'deletion of {tgtname} (' || OLD.id || ') is prohibited by link target policy', DETAIL = 'Object is still referenced in link __type__' - || ' of ' || {ns_prefix}._get_schema_object_name(OLD.id) || ' (' + || ' of ' || {ns_prefix}edgedb._get_schema_object_name(OLD.id) || ' (' || OLD.id || ').'; END IF; '''.format(tgtname=target.get_displayname(schema), ns_prefix=ns_prefix)) @@ -6200,8 +6200,12 @@ def _resolve_link_group( return tuple(group_var) def _get_inline_link_trigger_proc_text( - self, target, links, *, disposition, schema + self, target, links, *, disposition, schema, namespace ): + if namespace == defines.DEFAULT_NS: + ns_prefix = '' + else: + ns_prefix = namespace + '_' chunks = [] @@ -6279,6 +6283,7 @@ def _get_inline_link_trigger_proc_text( tgtname=target.get_displayname(schema), near_endpoint=near_endpoint, far_endpoint=far_endpoint, + ns_prefix=ns_prefix ) chunks.append(text) diff --git a/edb/schema/defines.py b/edb/schema/defines.py index 0da8afb396e..581f57f9f00 100644 --- a/edb/schema/defines.py +++ b/edb/schema/defines.py @@ -35,5 +35,6 @@ EDGEDB_SYSTEM_DB = '__edgedbsys__' EDGEDB_SPECIAL_DBS = {EDGEDB_TEMPLATE_DB, EDGEDB_SYSTEM_DB} +EDGEDB_OWNED_DBS = ['edgedbext', 'edgedb', 
'edgedbss', 'edgedbpub', 'edgedbstd', 'edgedbinstdata'] DEFAULT_NS = 'default' diff --git a/edb/schema/reflection/structure.py b/edb/schema/reflection/structure.py index f4dce49b924..1eefe954974 100644 --- a/edb/schema/reflection/structure.py +++ b/edb/schema/reflection/structure.py @@ -36,6 +36,7 @@ from edb.schema import inheriting as s_inh from edb.schema import links as s_links from edb.schema import name as sn +from edb.schema import namespace as s_ns from edb.schema import objects as s_obj from edb.schema import objtypes as s_objtypes from edb.schema import schema as s_schema @@ -794,7 +795,8 @@ def generate_structure(schema: s_schema.Schema) -> SchemaReflectionParts: qry += ' FILTER NOT .builtin' if issubclass(py_cls, s_obj.GlobalObject): - global_parts.append(qry) + if not issubclass(py_cls, s_ns.NameSpace): + global_parts.append(qry) else: local_parts.append(qry) diff --git a/edb/server/compiler/compiler.py b/edb/server/compiler/compiler.py index 414fb615185..58d91453782 100644 --- a/edb/server/compiler/compiler.py +++ b/edb/server/compiler/compiler.py @@ -1997,6 +1997,7 @@ def _try_compile( raise errors.ProtocolError('nothing to compile') rv = dbstate.QueryUnitGroup() + rv.namespace = ctx.namespace is_script = statements_len > 1 script_info = None @@ -2035,6 +2036,7 @@ def _try_compile( cardinality=default_cardinality, capabilities=capabilities, output_format=stmt_ctx.output_format, + namespace=ctx.namespace ) if not comp.is_transactional: diff --git a/edb/server/compiler/dbstate.py b/edb/server/compiler/dbstate.py index 9a13ea91b5a..38919e654f5 100644 --- a/edb/server/compiler/dbstate.py +++ b/edb/server/compiler/dbstate.py @@ -39,7 +39,7 @@ from edb.schema import objects as s_obj from edb.schema import schema as s_schema -from edb.server import config +from edb.server import config, defines from . import enums from . import sertypes @@ -316,6 +316,7 @@ class QueryUnit: # schema reflection sqls, only available if this is a ddl stmt. schema_refl_sqls: Tuple[bytes, ...] = None stdview_sqls: Tuple[bytes, ...] 
= None + namespace: str = defines.DEFAULT_NS @property def has_ddl(self) -> bool: @@ -371,6 +372,7 @@ class QueryUnitGroup: ref_ids: Optional[Set[uuid.UUID]] = None # Record affected object ids for cache clear affected_obj_ids: Optional[Set[uuid.UUID]] = None + namespace: str = defines.DEFAULT_NS def __iter__(self): return iter(self.units) diff --git a/edb/server/compiler_pool/pool.py b/edb/server/compiler_pool/pool.py index fbb0c1dd12f..6ea2adc623d 100644 --- a/edb/server/compiler_pool/pool.py +++ b/edb/server/compiler_pool/pool.py @@ -303,9 +303,9 @@ def _get_init_args(self): def _get_init_args_uncached(self): dbs: state.DatabasesState = immutables.Map() - for db_name, ns_map in self._dbindex.iter_dbs(): + for db in self._dbindex.iter_dbs(): namespace = immutables.Map() - for ns_name, ns_db in ns_map.items(): + for ns_name, ns_db in db.ns_map.items(): db_user_schema = ns_db.user_schema version_id = UNKNOW_VER_ID if db_user_schema is None else db_user_schema.version_id namespace.set( @@ -316,10 +316,10 @@ def _get_init_args_uncached(self): user_schema=db_user_schema, user_schema_version=version_id, reflection_cache=ns_db.reflection_cache, - database_config=ns_db.db_config, + database_config=db.db_config, ) ) - dbs = dbs.set(db_name, namespace) + dbs = dbs.set(db.name, namespace) init_args = ( dbs, @@ -462,7 +462,7 @@ def sync_worker_state_cb( worker._system_config = system_config worker_db: state.DatabaseState = worker._dbs.get(dbname, {}).get(namespace) - preargs = (dbname,) + preargs = (dbname, namespace) to_update = {} if worker_db is None: @@ -1194,23 +1194,23 @@ def __init__(self, **kwargs): @functools.lru_cache(maxsize=None) def _get_init_args(self): dbs: state.DatabasesState = immutables.Map() - for db_name, ns_map in self._dbindex.iter_dbs(): + for db in self._dbindex.iter_dbs(): namespace = immutables.Map() - for ns_name, ns_db in ns_map.items(): + for ns_name, ns_db in db.ns_map.items(): db_user_schema = ns_db.user_schema version_id = UNKNOW_VER_ID if db_user_schema is None else db_user_schema.version_id namespace.set( ns_name, state.DatabaseState( - name=db_name, + name=db.name, namespace=ns_name, user_schema=db_user_schema, user_schema_version=version_id, reflection_cache=ns_db.reflection_cache, - database_config=ns_db.db_config, + database_config=db.db_config, ) ) - dbs = dbs.set(db_name, namespace) + dbs = dbs.set(db.name, namespace) self._worker._dbs = dbs self._worker._backend_runtime_params = self._backend_runtime_params self._worker._std_schema = self._std_schema diff --git a/edb/server/compiler_pool/worker.py b/edb/server/compiler_pool/worker.py index a373f8f2ee6..a26a0e37baa 100644 --- a/edb/server/compiler_pool/worker.py +++ b/edb/server/compiler_pool/worker.py @@ -123,8 +123,8 @@ def __sync__( global INSTANCE_CONFIG try: - db = DBS.get(dbname, {}).get(namespace) - if db is None: + ns_db = DBS.get(dbname, {}).get(namespace) + if ns_db is None: assert user_schema is not None assert reflection_cache is not None assert database_config is not None @@ -132,17 +132,15 @@ def __sync__( reflection_cache_unpacked = pickle.loads(reflection_cache) database_config_unpacked = pickle.loads(database_config) ns = DBS.get(dbname, immutables.Map()) - ns.set( + ns_db = state.DatabaseState( + dbname, namespace, - state.DatabaseState( - dbname, - namespace, - user_schema_unpacked, - user_schema_unpacked.version_id, - reflection_cache_unpacked, - database_config_unpacked, - ) + user_schema_unpacked, + user_schema_unpacked.version_id, + reflection_cache_unpacked, + database_config_unpacked, 
             )
+            ns = ns.set(namespace, ns_db)
             DBS = DBS.set(dbname, ns)
         else:
             updates = {}
@@ -162,9 +160,9 @@ def __sync__(
 
             if updates:
-                db = db._replace(**updates)
+                ns_db = ns_db._replace(**updates)
                 ns = DBS[dbname]
-                ns.set(namespace, db)
+                ns = ns.set(namespace, ns_db)
                 DBS = DBS.set(dbname, ns)
 
         if global_schema is not None:
             GLOBAL_SCHEMA = pickle.loads(global_schema)
@@ -178,7 +176,7 @@
             f'failed to sync worker state: {type(ex).__name__}({ex})') from ex
 
     if need_return:
-        return db
+        return ns_db
 
 
 def compile(
diff --git a/edb/server/dbview/dbview.pxd b/edb/server/dbview/dbview.pxd
index 0fa4e6f8af5..44c15e14352 100644
--- a/edb/server/dbview/dbview.pxd
+++ b/edb/server/dbview/dbview.pxd
@@ -46,6 +46,7 @@ cdef class QueryRequestInfo:
     cdef public bint inline_objectids
     cdef public uint64_t allow_capabilities
     cdef public object module
+    cdef public object namespace
    cdef public bint read_only
    cdef public object external_view
    cdef public bint testmode
@@ -71,12 +72,26 @@ cdef class DatabaseIndex:
        object _factory
 
 
+cdef class NameSpace:
+    cdef:
+        public object _eql_to_compiled
+        public object _eql_to_compiled_disk
+        public object _object_id_to_eql
+        DatabaseIndex _dbindex
+        object _state_serializers
+        str _sql_bak_dir
+        bint _log_cache
+
+        readonly str name
+        public object user_schema
+        public object reflection_cache
+        public object backend_ids
+        public object extensions
+
+
 cdef class Database:
     cdef:
-        object _eql_to_compiled
-        object _eql_to_compiled_disk
-        object _object_id_to_eql
         DatabaseIndex _index
         object _views
         object _introspection_lock
@@ -85,30 +100,27 @@ cdef class Database:
         bint _log_cache
 
         readonly str name
-        readonly str namespace
+        public object ns_map
         readonly object dbver
         readonly object db_config
-        readonly object user_schema
-        readonly object reflection_cache
-        readonly object backend_ids
-        readonly object extensions
 
     cdef schedule_config_update(self)
 
-    cdef _invalidate_caches(self, drop_ids)
+    cdef _invalidate_caches(self)
     cdef _cache_compiled_query(self, key, query_unit)
     cdef _new_view(self, query_cache, protocol_version)
     cdef _remove_view(self, view)
-    cdef _update_backend_ids(self, new_types)
+    cdef _update_backend_ids(self, namespace, new_types)
     cdef _set_and_signal_new_user_schema(
         self,
+        namespace,
         new_schema,
         reflection_cache=?,
         backend_ids=?,
         db_config=?,
         affecting_ids=?,
     )
-    cdef get_state_serializer(self, protocol_version)
+    cdef get_state_serializer(self, namespace, protocol_version)
 
 
 cdef class DatabaseConnectionView:
@@ -138,8 +150,6 @@ cdef class DatabaseConnectionView:
         tuple _session_state_db_cache
         tuple _session_state_cache
 
-        object _eql_to_compiled
-
         object _txid
         object _in_tx_db_config
         object _in_tx_savepoints
@@ -167,12 +177,11 @@ cdef class DatabaseConnectionView:
 
         object __weakref__
 
-    cdef _invalidate_local_cache(self)
     cdef _reset_tx_state(self)
 
     cdef clear_tx_error(self)
 
     cdef rollback_tx_to_savepoint(self, name)
-    cdef declare_savepoint(self, name, spid)
+    cdef declare_savepoint(self, namespace, name, spid)
     cdef recover_aliases_and_config(self, modaliases, config, globals)
     cdef abort_tx(self)
 
@@ -185,12 +194,12 @@ cdef class DatabaseConnectionView:
     cdef tx_error(self)
 
     cdef start(self, query_unit)
-    cdef _start_tx(self)
+    cdef _start_tx(self, namespace)
     cdef _apply_in_tx(self, query_unit)
     cdef start_implicit(self, query_unit)
     cdef on_error(self)
     cdef commit_implicit_tx(
-        self, user_schema, user_schema_unpacked,
+        self, namespace, user_schema, user_schema_unpacked,
         user_schema_mutation, global_schema, cached_reflection,
         affecting_ids,
     )
@@ -201,7
+210,7 @@ cdef class DatabaseConnectionView: cpdef get_globals(self) cpdef set_globals(self, new_globals) - cdef get_state_serializer(self) + cdef get_state_serializer(self, namespace) cdef set_state_serializer(self, new_serializer) cdef update_database_config(self) @@ -215,8 +224,8 @@ cdef class DatabaseConnectionView: cpdef get_modaliases(self) cdef bytes serialize_state(self) - cdef bint is_state_desc_changed(self) - cdef describe_state(self) + cdef bint is_state_desc_changed(self, namespace=?) + cdef describe_state(self, namespace) cdef encode_state(self) - cdef decode_state(self, type_id, data) + cdef decode_state(self, type_id, data, namespace=?) cdef inline recode_global(self, serializer, k, v) diff --git a/edb/server/dbview/dbview.pyx b/edb/server/dbview/dbview.pyx index bd8d2c7b402..bc7cf847523 100644 --- a/edb/server/dbview/dbview.pyx +++ b/edb/server/dbview/dbview.pyx @@ -88,6 +88,7 @@ cdef class QueryRequestInfo: read_only: bint = False, testmode: bint = False, external_view: object = immutables.Map(), + namespace: str = defines.DEFAULT_NS, ): self.source = source self.protocol_version = protocol_version @@ -104,6 +105,7 @@ cdef class QueryRequestInfo: self.read_only = read_only self.testmode = testmode self.external_view = external_view + self.namespace = namespace self.cached_hash = hash(( self.source.cache_key(), @@ -118,7 +120,8 @@ cdef class QueryRequestInfo: self.inline_objectids, self.module, self.read_only, - self.testmode + self.testmode, + self.namespace )) def __hash__(self): @@ -138,7 +141,8 @@ cdef class QueryRequestInfo: self.inline_objectids == other.inline_objectids and self.module == other.module and self.read_only == other.read_only and - self.testmode == other.testmode + self.testmode == other.testmode and + self.namespace == other.namespace ) @@ -263,8 +267,7 @@ cdef format_eqls(raw_eqls): return "\n".join(msg) -cdef class Database: - +cdef class NameSpace: # Global LRU cache of compiled anonymous queries _eql_to_compiled: typing.Mapping[ typing.Tuple[QueryRequestInfo, @@ -289,35 +292,29 @@ cdef class Database: typing.Optional[immutables.Map] ] ] - + # Dict for object id to eql + _object_id_to_eql: EqlDict[ + uuid.UUID, + typing.Tuple[QueryRequestInfo, + typing.Optional[immutables.Map], + typing.Optional[immutables.Map] + ] + ] def __init__( self, - DatabaseIndex index, str name, - str namespace, + DatabaseIndex dbindex, *, object user_schema, - object db_config, object reflection_cache, object backend_ids, object extensions, ): self.name = name - self.namespace = namespace - - self.dbver = next_dbver() - - self._index = index - self._views = weakref.WeakSet() self._state_serializers = {} - - self._introspection_lock = asyncio.Lock() - self._eql_to_compiled = lru.LRUMapping(maxsize=defines._MAX_QUERIES_CACHE) self._eql_to_compiled_disk = RankedDiskCache() self._object_id_to_eql = EqlDict() - - self.db_config = db_config self.user_schema = user_schema self.reflection_cache = reflection_cache self.backend_ids = backend_ids @@ -328,58 +325,12 @@ cdef class Database: } else: self.extensions = extensions - - self._sql_bak_dir = os.path.join(self.server._runstate_dir, 'sql_bak') + self._dbindex = dbindex + self._sql_bak_dir = os.path.join(dbindex._server._runstate_dir, 'sql_bak', name) self._log_cache = logger.isEnabledFor(logging.DEBUG) and debug.flags.show_cache_info - @property - def server(self): - return self._index._server - - cdef schedule_config_update(self): - self._index._server._on_local_database_config_change(self.name) - - cdef 
_set_and_signal_new_user_schema( - self, - new_schema, - reflection_cache=None, - backend_ids=None, - db_config=None, - affecting_ids: typing.Set[uuid.UUID]=None - ): - if new_schema is None: - raise AssertionError('new_schema is not supposed to be None') - - self.dbver = next_dbver() - - self.user_schema = new_schema - - self.extensions = { - ext.get_name(new_schema).name - for ext in new_schema.get_objects(type=s_ext.Extension) - } - - if backend_ids is not None: - self.backend_ids = backend_ids - if reflection_cache is not None: - self.reflection_cache = reflection_cache - if db_config is not None: - self.db_config = db_config - - drop_ids = {DROP_IN_SCHEMA_DELTA} - - if affecting_ids: - drop_ids.update(affecting_ids.intersection(self._object_id_to_eql.keys())) - - self._invalidate_caches(drop_ids) - - cdef _update_backend_ids(self, new_types): - self.backend_ids.update(new_types) - - cdef _invalidate_caches(self, drop_ids: typing.Set[uuid.UUID]): + def invalidate_caches(self, drop_ids: typing.Set[uuid.UUID]): self._state_serializers.clear() - self._clear_http_cache() - if self._log_cache: logger.debug(f'Ids to drop: {drop_ids}.') @@ -390,13 +341,17 @@ cdef class Database: for eql in list(self._object_id_to_eql[obj_id]): if eql in self._eql_to_compiled: if self._log_cache: - logger.debug(f"Eql with sql:{format_eqls((eql,))} " - f"will be dropped for change of object with id <{obj_id}> in LRU cache.") + logger.debug( + f"Eql with sql:{format_eqls((eql,))} " + f"will be dropped for change of object with id <{obj_id}> in LRU cache." + ) del self._eql_to_compiled[eql] if eql in self._eql_to_compiled_disk: if self._log_cache: - logger.debug(f"Eql with sql:{format_eqls((eql,))} " - f"will be dropped for change of object with id <{obj_id}> in Disk cache.") + logger.debug( + f"Eql with sql:{format_eqls((eql,))} " + f"will be dropped for change of object with id <{obj_id}> in Disk cache." 
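
# --- illustrative aside (not part of the patch) -----------------------------
# Sketch of the invalidation index each NameSpace keeps above: every cached
# query records which schema object ids it referenced, so DDL evicts only the
# affected entries instead of flushing the whole cache.  Plain dicts stand in
# for the LRU mapping, the disk cache and the EqlDict helper.

from collections import defaultdict

eql_to_compiled = {}                 # cache key -> compiled query unit group
object_id_to_eql = defaultdict(set)  # schema object id -> {cache keys}

def cache_query(key, compiled, ref_ids):
    eql_to_compiled[key] = compiled
    for obj_id in ref_ids:
        object_id_to_eql[obj_id].add(key)

def invalidate(drop_ids):
    for obj_id in drop_ids:
        for key in object_id_to_eql.pop(obj_id, ()):
            eql_to_compiled.pop(key, None)

cache_query('q1', 'unit-1', ref_ids={'type-A'})
invalidate({'type-A'})
assert 'q1' not in eql_to_compiled
# -----------------------------------------------------------------------------
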
+ ) del self._eql_to_compiled_disk[eql] del self._object_id_to_eql[obj_id] @@ -404,34 +359,21 @@ cdef class Database: if self._log_cache: logger.debug('After invalidate, LRU Cache: \n' + format_eqls(self._eql_to_compiled._dict.keys())) logger.debug('Disk Cache: \n' + format_eqls(self._eql_to_compiled_disk.keys())) - logger.debug(f'Http CACHE: \n{self.server._http_query_cache._dict.keys()}') logger.debug(f'Obj id to Eql: \n{self._object_id_to_eql}') - def _clear_http_cache(self): - query_cache = self.server._http_query_cache - for cache_key in self.server.remove_on_ddl: - if cache_key in query_cache: - del query_cache[cache_key] - self.server.remove_on_ddl.clear() - def clear_caches(self): self._eql_to_compiled.clear() self._eql_to_compiled_disk.clear() self._object_id_to_eql.clear() - query_cache = self.server._http_query_cache - for cache_key in dict(query_cache._dict): - del query_cache[cache_key] - self.server.remove_on_ddl.clear() def view_caches(self): return '\n\n'.join([ f'LRU CACHE: \n{format_eqls(self._eql_to_compiled._dict.keys())}', f'Disk CACHE: \n{format_eqls(self._eql_to_compiled_disk.keys())}', - f'Http CACHE: \n{self.server._http_query_cache._dict.keys()}', f'Obj id to Eql: \n{self._object_id_to_eql}', ]) - cdef _cache_compiled_query(self, key, compiled: dbstate.QueryUnitGroup): + def _cache_compiled_query(self, key, compiled: dbstate.QueryUnitGroup): assert compiled.cacheable existing = self._eql_to_compiled.get(key) @@ -477,6 +419,150 @@ cdef class Database: for obj_id in compiled.ref_ids: self._object_id_to_eql.add(obj_id, key) + def get_state_serializer(self, protocol_version): + if protocol_version not in self._state_serializers: + self._state_serializers[protocol_version] = self._dbindex._factory.make( + self.user_schema, + self._dbindex._global_schema, + protocol_version, + ) + return self._state_serializers[protocol_version] + + def get_query_cache_size(self): + return len(self._eql_to_compiled) + + + +cdef class Database: + + def __init__( + self, + DatabaseIndex index, + str name, + str namespace, + *, + object user_schema, + object db_config, + object reflection_cache, + object backend_ids, + object extensions, + ): + self.name = name + self.dbver = next_dbver() + + self._index = index + self._views = weakref.WeakSet() + self._state_serializers = {} + + self._introspection_lock = asyncio.Lock() + + self.ns_map: typing.Dict[str, NameSpace] = { + namespace: NameSpace( + name=namespace, + dbindex=index, + user_schema=user_schema, + reflection_cache=reflection_cache, + backend_ids=backend_ids, + extensions=extensions + ) + } + self.db_config = db_config + self._log_cache = logger.isEnabledFor(logging.DEBUG) and debug.flags.show_cache_info + + @property + def server(self): + return self._index._server + + cdef schedule_config_update(self): + self._index._server._on_local_database_config_change(self.name) + + cdef _set_and_signal_new_user_schema( + self, + namespace, + new_schema, + reflection_cache=None, + backend_ids=None, + db_config=None, + affecting_ids: typing.Set[uuid.UUID]=None + ): + if new_schema is None: + raise AssertionError('new_schema is not supposed to be None') + + self.dbver = next_dbver() + if db_config is not None: + self.db_config = db_config + + extensions = { + ext.get_name(new_schema).name + for ext in new_schema.get_objects(type=s_ext.Extension) + } + if namespace not in self.ns_map: + ns = self.ns_map[namespace] = NameSpace( + name=namespace, + dbindex=self._index, + user_schema=new_schema, + reflection_cache=reflection_cache, + 
backend_ids=backend_ids, + extensions=extensions + ) + else: + ns = self.ns_map[namespace] + ns.user_schema = new_schema + ns.extensions = extensions + if reflection_cache is not None: + ns.reflection_cache = reflection_cache + if backend_ids is not None: + ns.backend_ids = backend_ids + drop_ids = {DROP_IN_SCHEMA_DELTA} + + if affecting_ids: + drop_ids.update(affecting_ids.intersection(ns._object_id_to_eql.keys())) + + ns.invalidate_caches(drop_ids) + + self._invalidate_caches() + + cdef _update_backend_ids(self, namespace, new_types): + self.ns_map[namespace].backend_ids.update(new_types) + + cdef _invalidate_caches(self): + self._clear_http_cache() + if self._log_cache: + logger.debug(f'Http CACHE: \n{self.server._http_query_cache._dict.keys()}') + + def _clear_http_cache(self): + query_cache = self.server._http_query_cache + for cache_key in self.server.remove_on_ddl: + if cache_key in query_cache: + del query_cache[cache_key] + self.server.remove_on_ddl.clear() + + def clear_caches(self): + for ns in self.ns_map.values(): + ns.clear_caches() + query_cache = self.server._http_query_cache + for cache_key in dict(query_cache._dict): + del query_cache[cache_key] + self.server.remove_on_ddl.clear() + + def view_caches(self): + return f'Http CACHE: \n{self.server._http_query_cache._dict.keys()}'\ + + '\n\n'.join( + [ + f"NameSpace({ns.name}): \n{ns.view_caches()}" + for ns in self.ns_map.values() + ] + ) + + cdef _cache_compiled_query(self, key, compiled: dbstate.QueryUnitGroup): + assert compiled.cacheable + + if compiled.namespace not in self.ns_map: + return + + ns = self.ns_map[compiled.namespace] + return ns._cache_compiled_query(key, compiled) + cdef _new_view(self, query_cache, protocol_version): view = DatabaseConnectionView( self, query_cache=query_cache, protocol_version=protocol_version @@ -487,33 +573,27 @@ cdef class Database: cdef _remove_view(self, view): self._views.remove(view) - cdef get_state_serializer(self, protocol_version): - if protocol_version not in self._state_serializers: - self._state_serializers[protocol_version] = self._index._factory.make( - self.user_schema, - self._index._global_schema, - protocol_version, + cdef get_state_serializer(self, namespace, protocol_version): + if namespace not in self.ns_map: + raise errors.InternalServerError( + f'NameSpace: [{namespace}] not in current db(ver:{self._db.dbver})' ) - return self._state_serializers[protocol_version] + return self.ns_map[namespace].get_state_serializer(protocol_version) def iter_views(self): yield from self._views - def get_query_cache_size(self): - return len(self._eql_to_compiled) - async def introspection(self): - if self.user_schema is None: + if any(ns.user_schema is None for ns in self.ns_map.values()): async with self._introspection_lock: - if self.user_schema is None: - await self._index._server.introspect_db(self.name, self.namespace) + await self._index._server.introspect(self.name) - async def persist_schema(self): + async def persist_schema(self, namespace): async with self._introspection_lock: - await self._index._server.persist_user_schema(self.name, self.namespace) + await self._index._server.persist_user_schema(self.name, namespace) - def schedule_schema_persistence(self): - asyncio.create_task(self.persist_schema()) + def schedule_schema_persistence(self, namespace): + asyncio.create_task(self.persist_schema(namespace)) def schedule_stdobj_inhview_update(self, sql): asyncio.create_task( @@ -522,13 +602,6 @@ cdef class Database: ) cdef class DatabaseConnectionView: - - 
_eql_to_compiled: typing.Mapping[ - typing.Tuple[QueryRequestInfo, - typing.Optional[immutables.Map], - typing.Optional[immutables.Map]], - dbstate.QueryUnitGroup] - def __init__(self, db: Database, *, query_cache, protocol_version): self._db = db @@ -557,15 +630,8 @@ cdef class DatabaseConnectionView: self._last_comp_state = None self._last_comp_state_id = None - # Whenever we are in a transaction that had executed a - # DDL command, we use this cache for compiled queries. - self._eql_to_compiled = lru.LRUMapping(maxsize=defines._MAX_QUERIES_CACHE) - self._reset_tx_state() - cdef _invalidate_local_cache(self): - self._eql_to_compiled.clear() - cdef _reset_tx_state(self): self._txid = None self._in_tx = False @@ -590,7 +656,6 @@ cdef class DatabaseConnectionView: self._in_tx_dbver = 0 self._in_tx_stdview_sqls = None self._in_tx_sp_sqls = [] - self._invalidate_local_cache() cdef clear_tx_error(self): self._tx_error = False @@ -615,14 +680,13 @@ cdef class DatabaseConnectionView: self.set_session_config(config) self.set_globals(globals) self.set_state_serializer(state_serializer) - self._invalidate_local_cache() - cdef declare_savepoint(self, name, spid): + cdef declare_savepoint(self, namespace, name, spid): state = ( self.get_modaliases(), self.get_session_config(), self.get_globals(), - self.get_state_serializer(), + self.get_state_serializer(namespace), ) self._in_tx_savepoints.append((name, spid, state)) @@ -649,12 +713,12 @@ cdef class DatabaseConnectionView: else: return self._globals - cdef get_state_serializer(self): + cdef get_state_serializer(self, namespace): if self._in_tx: if self._in_tx_state_serializer is None: # DDL in transaction, recalculate the state descriptor self._in_tx_state_serializer = self._db._index._factory.make( - self.get_user_schema(), + self.get_user_schema(namespace), self.get_global_schema(), self._protocol_version, ) @@ -663,6 +727,7 @@ cdef class DatabaseConnectionView: if self._state_serializer is None: # Executed a DDL, recalculate the state descriptor self._state_serializer = self._db.get_state_serializer( + namespace, self._protocol_version ) return self._state_serializer @@ -743,7 +808,7 @@ cdef class DatabaseConnectionView: else: return self._modaliases - def get_user_schema(self): + def get_user_schema(self, namespace: str): if self._in_tx: if self._in_tx_user_schema_mut_pickled: mutation = pickle.loads(self._in_tx_user_schema_mut_pickled) @@ -751,7 +816,18 @@ cdef class DatabaseConnectionView: self._in_tx_user_schema_mut_pickled = None return self._in_tx_user_schema else: - return self._db.user_schema + if namespace in self._db.ns_map: + return self._db.ns_map[namespace].user_schema + raise errors.InternalServerError( + f'NameSpace: [{namespace}] not in current db(ver:{self._db.dbver})' + ) + + def get_reflection_cache(self, namespace: str): + if namespace in self._db.ns_map: + return self._db.ns_map[namespace].reflection_cache + raise errors.InternalServerError( + f'NameSpace: [{namespace}] not in current db(ver:{self._db.dbver})' + ) def get_global_schema(self): if self._in_tx: @@ -764,15 +840,15 @@ cdef class DatabaseConnectionView: else: return self._db._index._global_schema - def get_schema(self): - user_schema = self.get_user_schema() + def get_schema(self, namespace): + user_schema = self.get_user_schema(namespace) return s_schema.ChainedSchema( self._db._index._std_schema, user_schema, self._db._index._global_schema, ) - def resolve_backend_type_id(self, type_id): + def resolve_backend_type_id(self, type_id, namespace: defines.DEFAULT_NS): 
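+        # Backend type OIDs are now tracked per namespace: the lookup
+        # below consults ns_map[namespace].backend_ids and raises
+        # InternalServerError for a namespace unknown to this database.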
type_id = str(type_id) if self._in_tx: @@ -781,7 +857,12 @@ cdef class DatabaseConnectionView: except KeyError: pass - tid = self._db.backend_ids.get(type_id) + if namespace not in self._db.ns_map: + raise errors.InternalServerError( + f'NameSpace: [{namespace}] not in current db(ver:{self._db.dbver})' + ) + + tid = self._db.ns_map[namespace].backend_ids.get(type_id) if tid is None: raise RuntimeError( f'cannot resolve backend OID for type {type_id}') @@ -812,8 +893,8 @@ cdef class DatabaseConnectionView: self._session_state_db_cache = (self._config, spec) return spec - cdef bint is_state_desc_changed(self): - serializer = self.get_state_serializer() + cdef bint is_state_desc_changed(self, namespace=defines.DEFAULT_NS): + serializer = self.get_state_serializer(namespace) if not self._in_tx: # We may have executed a query, or COMMIT/ROLLBACK - just use # the serializer we preserved before. NOTE: the schema might @@ -840,8 +921,8 @@ cdef class DatabaseConnectionView: return True - cdef describe_state(self): - return self.get_state_serializer().describe() + cdef describe_state(self, namespace): + return self.get_state_serializer(namespace).describe() cdef encode_state(self): modaliases = self.get_modaliases() @@ -883,11 +964,11 @@ cdef class DatabaseConnectionView: state['globals'] = {k: v.value for k, v in globals_.items()} return serializer.type_id, serializer.encode(state) - cdef decode_state(self, type_id, data): + cdef decode_state(self, type_id, data, namespace=defines.DEFAULT_NS): if not self._in_tx: # make sure we start clean self._state_serializer = None - serializer = self.get_state_serializer() + serializer = self.get_state_serializer(namespace) self._command_state_serializer = serializer if type_id == sertypes.NULL_TYPE_ID.bytes: @@ -954,14 +1035,6 @@ cdef class DatabaseConnectionView: def __get__(self): return self._db.name - property namespace: - def __get__(self): - return self._db.namespace - - property reflection_cache: - def __get__(self): - return self._db.reflection_cache - property dbver: def __get__(self): if self._in_tx and self._in_tx_dbver: @@ -983,9 +1056,7 @@ cdef class DatabaseConnectionView: key = (key, self.get_modaliases(), self.get_session_config()) - if self._in_tx_with_ddl: - self._eql_to_compiled[key] = query_unit_group - else: + if not self._in_tx_with_ddl: self._db._cache_compiled_query(key, query_unit_group) cdef lookup_compiled_query(self, object key): @@ -994,12 +1065,19 @@ cdef class DatabaseConnectionView: self._in_tx_with_ddl): return None + if key.namespace not in self._db.ns_map: + raise errors.InternalServerError( + f'NameSpace: [{key.namespace}] not in current db(ver:{self._db.dbver})' + ) + + ns = self._db.ns_map[key.namespace] + key = (key, self.get_modaliases(), self.get_session_config()) - query_unit_group = self._db._eql_to_compiled.get(key) + query_unit_group = ns._eql_to_compiled.get(key) if query_unit_group is None: - disk_filepath = self._db._eql_to_compiled_disk.get(key) + disk_filepath = ns._eql_to_compiled_disk.get(key) if disk_filepath is None: return None @@ -1008,9 +1086,9 @@ cdef class DatabaseConnectionView: if logger.isEnabledFor(logging.DEBUG): logger.debug(f'Find dumped sql bytes in disk deleted, ' f'drop Eql for Sql: {key[0].source.text()}.') - self._db._eql_to_compiled_disk.delete_with_cb( + ns._eql_to_compiled_disk.delete_with_cb( key, - self._db._object_id_to_eql.maybe_drop_with_eqls + ns._object_id_to_eql.maybe_drop_with_eqls ) return None @@ -1019,7 +1097,7 @@ cdef class DatabaseConnectionView: query_unit_group = 
pickle.load(disk_file) metrics.edgeql_cache_pickle_load_duration.observe(time.monotonic() - started_at) - self._db._eql_to_compiled[key] = query_unit_group + ns._eql_to_compiled[key] = query_unit_group return query_unit_group @@ -1033,7 +1111,7 @@ cdef class DatabaseConnectionView: if query_unit.tx_id is not None: self._txid = query_unit.tx_id - self._start_tx() + self._start_tx(query_unit.namespace) if self._in_tx and not self._txid: raise errors.InternalServerError('unset txid in transaction') @@ -1041,22 +1119,17 @@ cdef class DatabaseConnectionView: if self._in_tx: self._apply_in_tx(query_unit) - cdef _start_tx(self): + cdef _start_tx(self, namespace): self._in_tx = True self._in_tx_config = self._config self._in_tx_globals = self._globals self._in_tx_db_config = self._db.db_config self._in_tx_modaliases = self._modaliases - self._in_tx_base_user_schema = self._db.user_schema - self._in_tx_user_schema = self._db.user_schema + self._in_tx_base_user_schema = self._db.ns_map[namespace].user_schema + self._in_tx_user_schema = self._db.ns_map[namespace].user_schema self._in_tx_global_schema = self._db._index._global_schema self._in_tx_state_serializer = self._state_serializer - def sync_tx_base_schema(self): - if self._db.user_schema is self._in_tx_base_user_schema: - return - self._in_tx_base_user_schema = self._db.user_schema - cdef _apply_in_tx(self, query_unit): if query_unit.has_ddl: self._in_tx_with_ddl = True @@ -1085,7 +1158,7 @@ cdef class DatabaseConnectionView: self.raise_in_tx_error() if not self._in_tx: - self._start_tx() + self._start_tx(query_unit.namespace) self._apply_in_tx(query_unit) @@ -1099,15 +1172,15 @@ cdef class DatabaseConnectionView: await be_conn.sql_execute(sqls) self._in_tx_sp_sqls.clear() - def save_schema_mutation(self, mut, mut_bytes): + def save_schema_mutation(self, namespace, mut, mut_bytes): self._db._index._server.get_compiler_pool().append_schema_mutation( self.dbname, - self.namespace, + namespace, mut_bytes, mut, - self.get_user_schema(), + self.get_user_schema(namespace), self.get_global_schema(), - self.reflection_cache, + self.get_reflection_cache(namespace), self.get_database_config(), self.get_compilation_system_config(), ) @@ -1122,6 +1195,7 @@ cdef class DatabaseConnectionView: and (side_effects & SideEffects.SchemaChanges) ): self.save_schema_mutation( + query_unit.namespace, query_unit.user_schema_mutation_obj, query_unit.user_schema_mutation, ) @@ -1130,19 +1204,15 @@ cdef class DatabaseConnectionView: def _on_success(self, query_unit, new_types): side_effects = 0 - if query_unit.tx_savepoint_rollback: - # Need to invalidate the cache in case there were - # SET ALIAS or CONFIGURE or DDL commands. 
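For reference, the compiled-query lookup assembled in lookup_compiled_query
above is a two-tier cache: the per-namespace in-memory LRU first, then the
on-disk pickle cache, with disk hits promoted back into the LRU. A minimal
sketch, assuming ns is a NameSpace and leaving out the stale-file cleanup
and metrics:

    import pickle

    def lookup(ns, key):
        group = ns._eql_to_compiled.get(key)           # tier 1: in-memory LRU
        if group is not None:
            return group
        disk_path = ns._eql_to_compiled_disk.get(key)  # tier 2: pickled file
        if disk_path is None:
            return None
        with open(disk_path, 'rb') as disk_file:       # assumes file still exists
            group = pickle.load(disk_file)
        ns._eql_to_compiled[key] = group               # promote into the LRU
        return group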
- self._invalidate_local_cache() - if not self._in_tx: if new_types: - self._db._update_backend_ids(new_types) + self._db._update_backend_ids(query_unit.namespace, new_types) if query_unit.user_schema_mutation is not None: self._in_tx_dbver = next_dbver() self._state_serializer = None self._db._set_and_signal_new_user_schema( - query_unit.update_user_schema(self._db.user_schema), + query_unit.namespace, + query_unit.update_user_schema(self.get_user_schema(query_unit.namespace)), pickle.loads(query_unit.cached_reflection) if query_unit.cached_reflection is not None else None, @@ -1150,7 +1220,7 @@ cdef class DatabaseConnectionView: None, query_unit.affected_obj_ids ) - self._db.schedule_schema_persistence() + self._db.schedule_schema_persistence(query_unit.namespace) if query_unit.stdview_sqls: self._db.schedule_stdobj_inhview_update(query_unit.stdview_sqls) side_effects |= SideEffects.SchemaChanges @@ -1184,10 +1254,11 @@ cdef class DatabaseConnectionView: self._globals = self._in_tx_globals if self._in_tx_new_types: - self._db._update_backend_ids(self._in_tx_new_types) + self._db._update_backend_ids(query_unit.namespace, self._in_tx_new_types) if query_unit.user_schema_mutation is not None: self._state_serializer = None self._db._set_and_signal_new_user_schema( + query_unit.namespace, query_unit.update_user_schema(self._in_tx_base_user_schema), pickle.loads(query_unit.cached_reflection) if query_unit.cached_reflection is not None @@ -1196,7 +1267,7 @@ cdef class DatabaseConnectionView: None, query_unit.affected_obj_ids ) - self._db.schedule_schema_persistence() + self._db.schedule_schema_persistence(query_unit.namespace) if self._in_tx_stdview_sqls: self._db.schedule_stdobj_inhview_update(self._in_tx_stdview_sqls) side_effects |= SideEffects.SchemaChanges @@ -1225,7 +1296,7 @@ cdef class DatabaseConnectionView: return side_effects cdef commit_implicit_tx( - self, user_schema, user_schema_unpacked, + self, namespace, user_schema, user_schema_unpacked, user_schema_mutation, global_schema, cached_reflection, affecting_ids ): @@ -1237,7 +1308,7 @@ cdef class DatabaseConnectionView: self._globals = self._in_tx_globals if self._in_tx_new_types: - self._db._update_backend_ids(self._in_tx_new_types) + self._db._update_backend_ids(namespace, self._in_tx_new_types) if ( user_schema is not None @@ -1247,13 +1318,14 @@ cdef class DatabaseConnectionView: if user_schema_unpacked is not None: user_schema = user_schema_unpacked elif user_schema_mutation is not None: - base_user_schema = self._db.user_schema + base_user_schema = self.get_user_schema(namespace) user_schema = user_schema_mutation.apply(base_user_schema) else: user_schema = pickle.loads(user_schema) self._state_serializer = None self._db._set_and_signal_new_user_schema( + namespace, user_schema, pickle.loads(cached_reflection) if cached_reflection is not None @@ -1380,7 +1452,7 @@ cdef class DatabaseConnectionView: if self.in_tx(): result = await compiler_pool.compile_in_tx( self.dbname, - self.namespace, + query_req.namespace, self.txid, self._last_comp_state, self._last_comp_state_id, @@ -1403,10 +1475,10 @@ cdef class DatabaseConnectionView: else: result = await compiler_pool.compile( self.dbname, - self.namespace, - self.get_user_schema(), + query_req.namespace, + self.get_user_schema(query_req.namespace), self.get_global_schema(), - self.reflection_cache, + self.get_reflection_cache(query_req.namespace), self.get_database_config(), self.get_compilation_system_config(), query_req.source, @@ -1470,7 +1542,7 @@ cdef class DatabaseIndex: 
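The DatabaseIndex hunks below flatten the previous dbname -> namespace ->
Database nesting into a plain dbname -> Database map, with namespaces managed
inside each Database. A sketch of the resulting registration flow (namespace
names illustrative, argument lists abbreviated; the full signature is in
register_db further down):

    index.register_db('mydb', namespace='default', user_schema=s0)  # new Database
    index.register_db('mydb', namespace='tenant1', user_schema=s1)  # added to ns_map
    assert set(index.get_db('mydb').ns_map) == {'default', 'tenant1'}
    index.unregister_ns('mydb', 'tenant1')   # drops one namespace, keeps the db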
except KeyError: return 0 - return sum(len(ns._views) for ns in db.values()) + return len((db)._views) def get_sys_config(self): return self._sys_config @@ -1485,15 +1557,14 @@ cdef class DatabaseIndex: def has_db(self, dbname): return dbname in self._dbs - def get_db(self, dbname, namespace): + def get_db(self, dbname): try: - return self._dbs[dbname][namespace] + return self._dbs[dbname] except KeyError: - raise errors.UnknownDatabaseError( - f'database {dbname!r} (namespace: {namespace}) does not exist') + raise errors.UnknownDatabaseError(f'database {dbname!r} does not exist') - def maybe_get_db(self, dbname, namespace): - return self._dbs.get(dbname, {}).get(namespace) + def maybe_get_db(self, dbname): + return self._dbs.get(dbname) def get_global_schema(self): return self._global_schema @@ -1513,13 +1584,23 @@ cdef class DatabaseIndex: extensions=None, ): cdef Database db - db = self._dbs.get(dbname, {}).get(namespace) + db = self._dbs.get(dbname) if db is not None: - db._set_and_signal_new_user_schema( - user_schema, reflection_cache, backend_ids, db_config - ) + if namespace not in db.ns_map: + db.ns_map[namespace] = NameSpace( + name=namespace, + dbindex=db._index, + user_schema=user_schema, + reflection_cache=reflection_cache, + backend_ids=backend_ids, + extensions=extensions + ) + else: + db._set_and_signal_new_user_schema( + namespace, user_schema, reflection_cache, backend_ids, db_config + ) else: - db = Database( + self._dbs[dbname] = Database( self, dbname, namespace=namespace, @@ -1529,18 +1610,17 @@ cdef class DatabaseIndex: backend_ids=backend_ids, extensions=extensions, ) - ns_map = self._dbs.get(dbname, {}) - ns_map[namespace] = db - self._dbs[dbname] = ns_map def unregister_ns(self, dbname, namespace): - self._dbs.get(dbname, {}).pop(namespace) + if dbname not in self._dbs: + return + self._dbs[dbname].ns_map.pop(namespace) def unregister_db(self, dbname): self._dbs.pop(dbname) def iter_dbs(self): - return iter(self._dbs.items()) + return iter(self._dbs.values()) async def _save_system_overrides(self, conn): data = config.to_json( @@ -1607,10 +1687,10 @@ cdef class DatabaseIndex: await self._server._after_system_config_reset( op.setting_name) - def new_view(self, dbname: str, namespace: str, *, query_cache: bool, protocol_version): - db = self.get_db(dbname, namespace) + def new_view(self, dbname: str, *, query_cache: bool, protocol_version): + db = self.get_db(dbname) return (db)._new_view(query_cache, protocol_version) def remove_view(self, view: DatabaseConnectionView): - db = self.get_db(view.dbname, view.namespace) + db = self.get_db(view.dbname) return (db)._remove_view(view) diff --git a/edb/server/protocol/binary.pxd b/edb/server/protocol/binary.pxd index cf2c267642e..f6738b14f16 100644 --- a/edb/server/protocol/binary.pxd +++ b/edb/server/protocol/binary.pxd @@ -132,7 +132,7 @@ cdef class EdgeConnection(frontend.FrontendConnection): cdef WriteBuffer make_command_data_description_msg( self, dbview.CompiledQuery query ) - cdef WriteBuffer make_state_data_description_msg(self) + cdef WriteBuffer make_state_data_description_msg(self, namespace=?) 
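+    # State descriptors are now namespace-scoped: callers pass the
+    # namespace whose user schema the state serializer should describe;
+    # the default value is supplied in the .pyx implementation.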
cdef WriteBuffer make_command_complete_msg(self, capabilities, status) cdef inline reject_headers(self) diff --git a/edb/server/protocol/binary.pyx b/edb/server/protocol/binary.pyx index 1a37cc48b77..230ad1db1d7 100644 --- a/edb/server/protocol/binary.pyx +++ b/edb/server/protocol/binary.pyx @@ -515,7 +515,7 @@ cdef class EdgeConnection(frontend.FrontendConnection): msg_buf.write_bytes(b'\x00' * 32) msg_buf.end_message() buf.write_buffer(msg_buf) - + # TODO add namespace buf.write_buffer(self.make_state_data_description_msg()) self.write(buf) @@ -1005,10 +1005,10 @@ cdef class EdgeConnection(frontend.FrontendConnection): msg.end_message() return msg - cdef WriteBuffer make_state_data_description_msg(self): + cdef WriteBuffer make_state_data_description_msg(self, namespace=edbdef.DEFAULT_NS): cdef WriteBuffer msg - type_id, type_data = self.get_dbview().describe_state() + type_id, type_data = self.get_dbview().describe_state(namespace) msg = WriteBuffer.new_message(b's') msg.write_bytes(type_id.bytes) @@ -1092,7 +1092,8 @@ cdef class EdgeConnection(frontend.FrontendConnection): EdgeSeverity.EDGE_SEVERITY_NOTICE, errors.LogMessage.get_code(), 'server restart is required for the configuration ' - 'change to take effect') + 'change to take effect' + ) cdef dbview.QueryRequestInfo parse_execute_request(self): cdef: @@ -1140,8 +1141,10 @@ cdef class EdgeConnection(frontend.FrontendConnection): state_tid = self.buffer.read_bytes(16) state_data = self.buffer.read_len_prefixed_bytes() try: + # TODO add namespace self.get_dbview().decode_state(state_tid, state_data) except errors.StateMismatchError: + # TODO add namespace self.write(self.make_state_data_description_msg()) raise @@ -1274,6 +1277,7 @@ cdef class EdgeConnection(frontend.FrontendConnection): if self._cancelled: raise ConnectionAbortedError + # TODO add namespace if _dbview.is_state_desc_changed(): self.write(self.make_state_data_description_msg()) self.write( @@ -2032,7 +2036,7 @@ cdef class EdgeConnection(frontend.FrontendConnection): server = self.server compiler_pool = server.get_compiler_pool() - + # TODO add namespace global_schema = _dbview.get_global_schema() user_schema = _dbview.get_user_schema() @@ -2236,8 +2240,9 @@ cdef class EdgeConnection(frontend.FrontendConnection): self._in_dump_restore = False server.release_pgcon(dbname, pgcon) - await server.introspect_db(dbname) + await server.introspect(dbname) + # TODO add namespace if _dbview.is_state_desc_changed(): self.write(self.make_state_data_description_msg()) diff --git a/edb/server/protocol/binary_v0.pyx b/edb/server/protocol/binary_v0.pyx index f8bd843c5ed..e381d959da6 100644 --- a/edb/server/protocol/binary_v0.pyx +++ b/edb/server/protocol/binary_v0.pyx @@ -20,7 +20,7 @@ import asyncio cdef tuple MIN_LEGACY_PROTOCOL = edbdef.MIN_LEGACY_PROTOCOL -from edb.server import args as srvargs +from edb.server import args as srvargs, defines from edb.server.protocol cimport args_ser from edb.server.protocol import execute @@ -282,7 +282,6 @@ cdef class EdgeConnectionBackwardsCompatible(EdgeConnection): compiler_pool = server.get_compiler_pool() dbname = _dbview.dbname - namespace = _dbview.namespace pgcon = await server.acquire_pgcon(dbname) self._in_dump_restore = True try: @@ -312,9 +311,9 @@ cdef class EdgeConnectionBackwardsCompatible(EdgeConnection): SET statement_timeout = 0; ''', ) - - user_schema = await server.introspect_user_schema(dbname, namespace, pgcon) - global_schema = await server.introspect_global_schema(namespace, pgcon) + # TODO add namespace + user_schema = 
await server.introspect_user_schema(dbname, conn=pgcon) + global_schema = await server.introspect_global_schema(pgcon) db_config = await server.introspect_db_config(pgcon) dump_protocol = self.max_protocol @@ -456,9 +455,9 @@ cdef class EdgeConnectionBackwardsCompatible(EdgeConnection): server = self.server compiler_pool = server.get_compiler_pool() - + # TODO add namespace global_schema = _dbview.get_global_schema() - user_schema = _dbview.get_user_schema() + user_schema = _dbview.get_user_schema(defines.DEFAULT_NS) dump_server_ver_str = None headers_num = self.buffer.read_int16() @@ -516,11 +515,11 @@ cdef class EdgeConnectionBackwardsCompatible(EdgeConnection): self.buffer.finish_message() dbname = _dbview.dbname - namespace = _dbview.namespace pgcon = await server.acquire_pgcon(dbname) self._in_dump_restore = True try: + # TODO add namespace _dbview.decode_state(sertypes.NULL_TYPE_ID.bytes, b'') await self._execute_utility_stmt( 'START TRANSACTION ISOLATION SERIALIZABLE', @@ -661,7 +660,7 @@ cdef class EdgeConnectionBackwardsCompatible(EdgeConnection): self._in_dump_restore = False server.release_pgcon(dbname, pgcon) - await server.introspect_db(dbname, namespace) + await server.introspect(dbname) msg = WriteBuffer.new_message(b'C') msg.write_int16(0) # no headers @@ -936,6 +935,7 @@ cdef class EdgeConnectionBackwardsCompatible(EdgeConnection): bint inline_objectids = True bytes stmt_name = b'' str module = None + str namespace = None bint read_only = False headers = self.legacy_parse_headers() @@ -955,6 +955,8 @@ cdef class EdgeConnectionBackwardsCompatible(EdgeConnection): module = v.decode() elif k == QUERY_HEADER_PROHIBIT_MUTATION: read_only = parse_boolean(v, "PROHIBIT_MUTATION") + elif k == QUERY_HEADER_EXPLICIT_NS: + namespace = v.decode() else: raise errors.BinaryProtocolError( f'unexpected message header: {k}' @@ -988,7 +990,8 @@ cdef class EdgeConnectionBackwardsCompatible(EdgeConnection): inline_objectids=inline_objectids, allow_capabilities=allow_capabilities, module=module, - read_only=read_only + read_only=read_only, + namespace=namespace ) return eql, query_req, stmt_name @@ -1081,13 +1084,15 @@ cdef class EdgeConnectionBackwardsCompatible(EdgeConnection): skip_first: bool, module: str = None, read_only: bool = False, + namespace: str = defines.DEFAULT_NS, ): query_req = dbview.QueryRequestInfo( source=edgeql.Source.from_string(query.decode("utf-8")), protocol_version=self.protocol_version, output_format=FMT_NONE, module=module, - read_only=read_only + read_only=read_only, + namespace=namespace, ) return await self.get_dbview()._compile( @@ -1140,6 +1145,7 @@ cdef class EdgeConnectionBackwardsCompatible(EdgeConnection): module = None read_only = False + namespace = defines.DEFAULT_NS headers = self.legacy_parse_headers() if headers: for k, v in headers.items(): @@ -1147,6 +1153,8 @@ cdef class EdgeConnectionBackwardsCompatible(EdgeConnection): allow_capabilities = parse_capabilities_header(v) elif k == QUERY_HEADER_EXPLICIT_MODULE: module = v.decode() + elif k == QUERY_HEADER_EXPLICIT_NS: + namespace = v.decode() elif k == QUERY_HEADER_PROHIBIT_MUTATION: read_only = parse_boolean(v, "PROHIBIT_MUTATION") else: @@ -1184,7 +1192,7 @@ cdef class EdgeConnectionBackwardsCompatible(EdgeConnection): query_unit = await self._legacy_simple_query( eql, allow_capabilities, skip_first, - module, read_only) + module, read_only, namespace) packet = WriteBuffer.new() packet.write_buffer(self.make_legacy_command_complete_msg(query_unit)) @@ -1199,6 +1207,7 @@ cdef class 
EdgeConnectionBackwardsCompatible(EdgeConnection): skip_first: bint, module: str = None, read_only: bool = False, + namespace: str = defines.DEFAULT_NS, ): cdef: bytes state = None, orig_state = None @@ -1207,7 +1216,9 @@ cdef class EdgeConnectionBackwardsCompatible(EdgeConnection): pgcon.PGConnection conn unit_group = await self._legacy_compile_script( - eql, skip_first=skip_first, module=module, read_only=read_only) + eql, skip_first=skip_first, module=module, read_only=read_only, + namespace=namespace + ) if self._cancelled: raise ConnectionAbortedError @@ -1270,10 +1281,10 @@ cdef class EdgeConnectionBackwardsCompatible(EdgeConnection): orig_state = None if query_unit.create_db: - await self.server.introspect_db(query_unit.create_db) + await self.server.introspect(query_unit.create_db) if query_unit.create_ns: - await self.server.introspect_db(_dbview.dbname, query_unit.create_ns) + await self.server.introspect(_dbview.dbname, query_unit.create_ns) if query_unit.drop_db: self.server._on_after_drop_db(query_unit.drop_db) @@ -1294,7 +1305,7 @@ cdef class EdgeConnectionBackwardsCompatible(EdgeConnection): else: side_effects = _dbview.on_success(query_unit, new_types) if side_effects: - execute.signal_side_effects(_dbview, side_effects) + execute.signal_side_effects(_dbview, query_unit.namespace, side_effects) if not _dbview.in_tx(): state = _dbview.serialize_state() if state is not orig_state: diff --git a/edb/server/protocol/edgeql_ext.pyx b/edb/server/protocol/edgeql_ext.pyx index 58881a2060e..ea8b5086f5b 100644 --- a/edb/server/protocol/edgeql_ext.pyx +++ b/edb/server/protocol/edgeql_ext.pyx @@ -26,7 +26,7 @@ import immutables from edb import errors from edb import edgeql -from edb.server import defines as edbdef +from edb.server import defines as edbdef, defines from edb.server.protocol import execute from edb.common import debug @@ -64,6 +64,7 @@ async def handle_request( query = None module = None limit = 0 + namespace = defines.DEFAULT_NS try: if request.method == b'POST': @@ -76,6 +77,7 @@ async def handle_request( variables = body.get('variables') globals_ = body.get('globals') module = body.get('module') + namespace = body.get('namespace', defines.DEFAULT_NS) limit = body.get('limit', 0) else: raise TypeError( @@ -110,6 +112,12 @@ async def handle_request( if module is not None: module = module[0] + namespace = qs.get('namespace') + if namespace is not None: + namespace = namespace[0] + else: + namespace = defines.DEFAULT_NS + limit = qs.get('limit') if limit is not None: limit = int(limit[0]) @@ -130,6 +138,9 @@ async def handle_request( if module is not None and not isinstance(module, str): raise TypeError('"module" must be a str object') + if namespace is not None and not isinstance(namespace, str): + raise TypeError('"namespace" must be a str object') + if limit is not None and not isinstance(limit, int): raise TypeError('"limit" must be an integer object') @@ -147,6 +158,7 @@ async def handle_request( try: result = await execute.parse_execute_json( db, + namespace, query, variables=variables or {}, globals_=globals_ or {}, diff --git a/edb/server/protocol/execute.pyx b/edb/server/protocol/execute.pyx index ebcdbfb06e7..4df2ce6692c 100644 --- a/edb/server/protocol/execute.pyx +++ b/edb/server/protocol/execute.pyx @@ -85,7 +85,7 @@ async def execute( if query_unit.create_ns: await server.create_namespace(be_conn, query_unit.create_ns) if query_unit.drop_ns: - await server._on_before_drop_ns(query_unit.drop_ns, dbv.namespace) + await 
server._on_before_drop_ns(query_unit.drop_ns, query_unit.namespace) if query_unit.system_config: await execute_system_config(be_conn, dbv, query_unit) else: @@ -131,13 +131,14 @@ async def execute( if query_unit.tx_savepoint_declare: dbv.declare_savepoint( - query_unit.sp_name, query_unit.sp_id) + query_unit.namespace, query_unit.sp_name, query_unit.sp_id + ) if query_unit.create_db: - await server.introspect_db(query_unit.create_db) + await server.introspect(query_unit.create_db) if query_unit.create_ns: - await server.introspect_db(dbv.dbname, query_unit.create_ns) + await server.introspect(dbv.dbname, query_unit.create_ns) if query_unit.drop_db: server._on_after_drop_db(query_unit.drop_db) @@ -160,7 +161,7 @@ async def execute( else: side_effects = dbv.on_success(query_unit, new_types) if side_effects: - signal_side_effects(dbv, side_effects) + signal_side_effects(dbv, query_unit.namespace, side_effects) if not dbv.in_tx(): state = dbv.serialize_state() if state is not orig_state: @@ -287,7 +288,7 @@ async def execute_script( and query_unit.user_schema_mutation ): if user_schema_unpacked is None: - base_user_schema = user_schema or dbv.get_user_schema() + base_user_schema = user_schema or dbv.get_user_schema(query_unit.namespace) else: base_user_schema = user_schema_unpacked @@ -322,16 +323,16 @@ async def execute_script( gmut_unpickled = pickle.loads(group_mutation) side_effects = dbv.commit_implicit_tx( - user_schema, user_schema_unpacked, gmut_unpickled, + unit_group.namespace, user_schema, user_schema_unpacked, gmut_unpickled, global_schema, cached_reflection, unit_group.affected_obj_ids ) if side_effects: - signal_side_effects(dbv, side_effects) + signal_side_effects(dbv, query_unit.namespace, side_effects) if ( side_effects & dbview.SideEffects.SchemaChanges and group_mutation is not None ): - dbv.save_schema_mutation(gmut_unpickled, group_mutation) + dbv.save_schema_mutation(query_unit.namespace, gmut_unpickled, group_mutation) state = dbv.serialize_state() if state is not orig_state: @@ -378,7 +379,7 @@ async def execute_system_config( await conn.sql_execute(b'SELECT pg_reload_conf()') -def signal_side_effects(dbv, side_effects): +def signal_side_effects(dbv, namespace, side_effects): server = dbv.server if not server._accept_new_tasks: return @@ -388,7 +389,7 @@ def signal_side_effects(dbv, side_effects): server._signal_sysevent( 'schema-changes', dbname=dbv.dbname, - namespace=dbv.namespace, + namespace=namespace, ), interruptable=False, ) @@ -397,7 +398,6 @@ def signal_side_effects(dbv, side_effects): server.create_task( server._signal_sysevent( 'global-schema-changes', - namespace=dbv.namespace, ), interruptable=False, ) @@ -422,6 +422,7 @@ def signal_side_effects(dbv, side_effects): async def parse_execute( db: dbview.Database, + namespace: str, query: str, *, external_view: Mapping = immutables.Map(), @@ -431,8 +432,7 @@ async def parse_execute( dbv = await server.new_dbview( dbname=db.name, query_cache=False, - protocol_version=edbdef.CURRENT_PROTOCOL, - namespace=db.namespace + protocol_version=edbdef.CURRENT_PROTOCOL ) query_req = dbview.QueryRequestInfo( @@ -442,7 +442,8 @@ async def parse_execute( output_format=compiler.OutputFormat.NONE, allow_capabilities=compiler.Capability.MODIFICATIONS | compiler.Capability.DDL, external_view=external_view, - testmode=testmode + testmode=testmode, + namespace=namespace ) compiled = await dbv.parse(query_req) @@ -461,6 +462,7 @@ async def parse_execute( async def parse_execute_json( db: dbview.Database, + namespace: str, query: 
str, *, variables: Mapping[str, Any] = immutables.Map(), @@ -480,7 +482,6 @@ async def parse_execute_json( dbname=db.name, query_cache=query_cache_enabled, protocol_version=edbdef.CURRENT_PROTOCOL, - namespace=db.namespace ) allow_cap = compiler.Capability(0) if read_only else compiler.Capability.MODIFICATIONS @@ -493,6 +494,7 @@ async def parse_execute_json( allow_capabilities=allow_cap, read_only=read_only, module=module, + namespace=namespace, force_limit=limit ) diff --git a/edb/server/protocol/extern_obj.py b/edb/server/protocol/extern_obj.py index 972a1fb6b81..38fa540202c 100644 --- a/edb/server/protocol/extern_obj.py +++ b/edb/server/protocol/extern_obj.py @@ -25,6 +25,7 @@ from edb import errors +from edb.server import defines from edb.server.protocol import execute from edb.pgsql.types import base_type_name_map_r @@ -305,6 +306,7 @@ def _unknown_path(): try: if request.content_type and b'json' in request.content_type: body = json.loads(request.body) + namespace = body.pop('namespace', defines.DEFAULT_NS) if not isinstance(body, dict): raise TypeError( 'the body of the request must be a JSON object') @@ -334,6 +336,7 @@ def _unknown_path(): try: await execute.parse_execute( db, + namespace, req.to_ddl(), external_view=req.resolve_view(), testmode=bool(request.testmode) diff --git a/edb/server/protocol/infer_expr.py b/edb/server/protocol/infer_expr.py index c990b4509a1..9fd35fae996 100644 --- a/edb/server/protocol/infer_expr.py +++ b/edb/server/protocol/infer_expr.py @@ -23,6 +23,7 @@ from edb import errors from edb.common import debug from edb.common import markup +from edb.server import defines async def handle_request( @@ -52,6 +53,7 @@ async def handle_request( 'the body of the request must be a JSON object' ) module = body.get('module') + namespace = body.get('namespace', defines.DEFAULT_NS) objname = body.get('object') expr = body.get('expression') else: @@ -68,6 +70,8 @@ async def handle_request( if not isinstance(module, str): raise TypeError("Field 'module' must be a string.") + if not isinstance(namespace, str): + raise TypeError("Field 'namespace' must be a string.") if not isinstance(objname, str): raise TypeError("Field 'object' must be a string.") if not isinstance(expr, str): @@ -88,7 +92,7 @@ async def handle_request( await db.introspection() try: - result = await execute(db, server, module, objname, expr) + result = await execute(db, server, namespace, module, objname, expr) except Exception as ex: if debug.flags.server: markup.dump(ex) @@ -108,13 +112,13 @@ async def handle_request( response.body = json.dumps(result).encode() -async def execute(db, server, module: str, objname: str, expression: str): +async def execute(db, server, namespace: str, module: str, objname: str, expression: str): dbver = db.dbver query_cache = server._http_query_cache name_str = f"{module}::{objname}" - cache_key = ('infer_expr', name_str, expression, dbver, module) + cache_key = ('infer_expr', name_str, expression, dbver, module, namespace) entry = query_cache.get(cache_key, None) @@ -124,6 +128,7 @@ async def execute(db, server, module: str, objname: str, expression: str): compiler_pool = server.get_compiler_pool() result = await compiler_pool.infer_expr( db.name, + namespace, db.user_schema, server.get_global_schema(), db.reflection_cache, diff --git a/edb/server/protocol/notebook_ext.pyx b/edb/server/protocol/notebook_ext.pyx index 1cdccdad95a..fc1baa034dd 100644 --- a/edb/server/protocol/notebook_ext.pyx +++ b/edb/server/protocol/notebook_ext.pyx @@ -175,8 +175,7 @@ async def 
execute(db, server, queries: list): dbv = await server.new_dbview( dbname=db.name, query_cache=False, - protocol_version=edbdef.CURRENT_PROTOCOL, - namespace=db.namespace + protocol_version=edbdef.CURRENT_PROTOCOL ) bind_data = None diff --git a/edb/server/protocol/schema_info.py b/edb/server/protocol/schema_info.py index 7f3acdce8cd..8b0fe22d76e 100644 --- a/edb/server/protocol/schema_info.py +++ b/edb/server/protocol/schema_info.py @@ -25,6 +25,7 @@ from edb import errors from edb.common import debug from edb.common import markup +from edb.server import defines async def handle_request( @@ -41,6 +42,7 @@ async def handle_request( return query_uuid = None + namespace = defines.DEFAULT_NS try: if request.method == b'POST': @@ -50,6 +52,7 @@ async def handle_request( raise TypeError( 'the body of the request must be a JSON object') query_uuid = body.get('uuid') + namespace = body.get('namespace', defines.DEFAULT_NS) else: raise TypeError( 'unable to interpret SchemaInfo POST request') @@ -61,6 +64,11 @@ async def handle_request( query_uuid = qs.get('uuid') if query_uuid is not None: query_uuid = query_uuid[0] + namespace = qs.get('namespace') + if namespace is not None: + namespace = namespace[0] + else: + namespace = defines.DEFAULT_NS else: raise TypeError('expected a GET or a POST request') @@ -80,7 +88,7 @@ async def handle_request( response.content_type = b'application/json' await db.introspection() try: - result = await execute(db, server, query_uuid) + result = await execute(db, server, namespace, query_uuid) except Exception as ex: if debug.flags.server: markup.dump(ex) @@ -101,8 +109,12 @@ async def handle_request( response.body = b'{"data":' + result + b'}' -async def execute(db, server, query_uuid: str): - user_schema = db.user_schema +async def execute(db, server, namespace: str, query_uuid: str): + if namespace not in db.ns_map: + raise errors.InternalServerError( + f'NameSpace: [{namespace}] not in current db(ver:{db.dbver})' + ) + user_schema = db.ns_map[namespace].user_schema global_schema = server.get_global_schema() obj_id = uuid.UUID(query_uuid) diff --git a/edb/server/protocol/system_api.py b/edb/server/protocol/system_api.py index 4a590b73894..e674a1f048a 100644 --- a/edb/server/protocol/system_api.py +++ b/edb/server/protocol/system_api.py @@ -25,7 +25,7 @@ from edb.common import debug from edb.common import markup -from edb.server import compiler +from edb.server import compiler, defines from edb.server import defines as edbdef from . 
import execute # type: ignore @@ -90,6 +90,7 @@ async def handle_status_request( db = server.get_db(dbname=edbdef.EDGEDB_SYSTEM_DB) result = await execute.parse_execute_json( db, + defines.DEFAULT_NS, query="SELECT 'OK'", output_format=compiler.OutputFormat.JSON_ELEMENTS, # Disable query cache because we need to ensure that the compiled diff --git a/edb/server/server.py b/edb/server/server.py index 230079110d2..45eabba974f 100644 --- a/edb/server/server.py +++ b/edb/server/server.py @@ -417,7 +417,7 @@ async def init(self): await self._load_instance_data() - global_schema = await self.introspect_global_schema(defines.DEFAULT_NS) + global_schema = await self.introspect_global_schema() sys_config = await self.load_sys_config() await self.load_reported_config() @@ -541,19 +541,19 @@ def get_compiler_pool(self): def get_suggested_client_pool_size(self) -> int: return self._suggested_client_pool_size - def get_db(self, *, dbname: str, namespace: str = defines.DEFAULT_NS): + def get_db(self, *, dbname: str): assert self._dbindex is not None - return self._dbindex.get_db(dbname, namespace) + return self._dbindex.get_db(dbname) - def maybe_get_db(self, *, dbname: str, namespace: str = defines.DEFAULT_NS): + def maybe_get_db(self, *, dbname: str): assert self._dbindex is not None - return self._dbindex.maybe_get_db(dbname, namespace) + return self._dbindex.maybe_get_db(dbname) - async def new_dbview(self, *, dbname, query_cache, protocol_version, namespace: str = defines.DEFAULT_NS): - db = self.get_db(dbname=dbname, namespace=namespace) + async def new_dbview(self, *, dbname, query_cache, protocol_version): + db = self.get_db(dbname=dbname) await db.introspection() return self._dbindex.new_view( - dbname, namespace=namespace, query_cache=query_cache, protocol_version=protocol_version + dbname, query_cache=query_cache, protocol_version=protocol_version ) def remove_dbview(self, dbview): @@ -640,11 +640,8 @@ async def load_reported_config(self): finally: self._release_sys_pgcon() - async def introspect_global_schema(self, namespace, conn=None): - intro_query = _RE_BYTES_REPL_NS.sub( - namespace.encode('utf-8') + rb'_\1\2', - self._global_intro_query - ) + async def introspect_global_schema(self, conn=None): + intro_query = self._global_intro_query if conn is not None: json_data = await conn.sql_fetch_val(intro_query) else: @@ -661,22 +658,25 @@ async def introspect_global_schema(self, namespace, conn=None): schema_class_layout=self._schema_class_layout, ) - async def _reintrospect_global_schema(self, namespace): + async def _reintrospect_global_schema(self): if not self._initing and not self._serving: logger.warning( "global-schema-changes event received during shutdown; " "ignoring." ) return - new_global_schema = await self.introspect_global_schema(namespace) + new_global_schema = await self.introspect_global_schema() self._dbindex.update_global_schema(new_global_schema) self._fetch_roles() async def introspect_user_schema(self, dbname, namespace, conn): await self._persist_user_schema(dbname, namespace, conn) - + if namespace == defines.DEFAULT_NS: + ns_prefix = '' + else: + ns_prefix = namespace + '_' ns_intro_query = _RE_BYTES_REPL_NS.sub( - namespace.encode('utf-8') + rb'_\1\2', + ns_prefix.encode('utf-8') + rb'\1\2', self._local_intro_query ) json_data = await conn.sql_fetch_val(ns_intro_query) @@ -712,8 +712,8 @@ async def _acquire_intro_pgcon(self, dbname): raise return conn - async def introspect_db(self, dbname, namespace: str = None): - """Use this method to (re-)introspect a DB. 
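Earlier in this file, introspect_user_schema now templates the introspection
SQL per namespace: the default namespace keeps the bare edgedb* schema names,
while any other namespace gets a '<namespace>_' prefix. An illustrative
sketch (the real _RE_BYTES_REPL_NS pattern is defined elsewhere in server.py;
the one below is a hypothetical stand-in):

    import re

    _RE_NS = re.compile(rb'\b(edgedb)(\w*)')   # hypothetical stand-in

    def render_intro_query(sql: bytes, namespace: str, default_ns: str) -> bytes:
        ns_prefix = '' if namespace == default_ns else namespace + '_'
        return _RE_NS.sub(ns_prefix.encode('utf-8') + rb'\1\2', sql)

    render_intro_query(b'SELECT id FROM edgedbstd."_SchemaType"', 'tenant1', 'default')
    # -> b'SELECT id FROM tenant1_edgedbstd."_SchemaType"'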
+ async def introspect(self, dbname, namespace: str = None): + """Use this method to (re-)introspect a DB or namespace. If the DB is already registered in self._dbindex, its schema, config, etc. would simply be updated. If it's missing @@ -742,11 +742,11 @@ async def introspect_db(self, dbname, namespace: str = None): ns_list = [namespace] for ns in ns_list: - await self.introspect_ns(conn, dbname, ns) + await self._introspect_ns(conn, dbname, ns) finally: self.release_pgcon(dbname, conn) - async def introspect_ns(self, conn, dbname, namespace): + async def _introspect_ns(self, conn, dbname, namespace): user_schema = await self.introspect_user_schema(dbname, namespace, conn) if namespace == defines.DEFAULT_NS: schema_name = 'edgedb' @@ -1304,7 +1304,7 @@ def _on_remote_ddl(self, dbname, namespace): # on the __edgedb_sysevent__ channel async def task(): try: - await self.introspect_db(dbname, namespace) + await self.introspect(dbname, namespace) except Exception: metrics.background_errors.inc(1.0, 'on_remote_ddl') raise @@ -1319,7 +1319,7 @@ def _on_remote_database_config_change(self, dbname): # on the __edgedb_sysevent__ channel async def task(): try: - await self.introspect_db(dbname) + await self.introspect(dbname) except Exception: metrics.background_errors.inc( 1.0, 'on_remote_database_config_change') @@ -1336,7 +1336,7 @@ def _on_local_database_config_change(self, dbname): # of the DB and update all components of it. async def task(): try: - await self.introspect_db(dbname) + await self.introspect(dbname) except Exception: metrics.background_errors.inc( 1.0, 'on_local_database_config_change') @@ -1367,7 +1367,7 @@ def _on_global_schema_change(self, namespace): async def task(ns): try: - await self._reintrospect_global_schema(ns) + await self._reintrospect_global_schema() except Exception: metrics.background_errors.inc( 1.0, 'on_global_schema_change') @@ -1999,31 +1999,33 @@ def serialize_config(cfg): ) dbs = {} - for db_name, ns_map in self._dbindex.iter_dbs(): - if db_name in defines.EDGEDB_SPECIAL_DBS: + for db in self._dbindex.iter_dbs(): + if db.name in defines.EDGEDB_SPECIAL_DBS: continue ns = {} - for ns_name, ns_db in ns_map.items(): + for ns_name, ns_db in db.ns_map.items(): ns[ns_name] = dict( - name=ns_db.name, namespace=ns_name, - dbver=ns_db.dbver, - config=serialize_config(ns_db.db_config), extensions=sorted(ns_db.extensions), - query_cache_size=ns_db.get_query_cache_size(), - connections=[ - dict( - in_tx=view.in_tx(), - in_tx_error=view.in_tx_error(), - config=serialize_config(view.get_session_config()), - module_aliases=view.get_modaliases(), - ) - for view in ns_db.iter_views() - ], + query_cache_size=ns_db.get_query_cache_size() ) - dbs[db_name] = dict(name=db_name, namespace=ns) + dbs[db.name] = dict( + name=db.name, + dbver=db.dbver, + config=serialize_config(db.db_config), + connections=[ + dict( + in_tx=view.in_tx(), + in_tx_error=view.in_tx_error(), + config=serialize_config(view.get_session_config()), + module_aliases=view.get_modaliases(), + ) + for view in db.iter_views() + ], + namespace=ns + ) obj['databases'] = dbs From ed3de3ca1db4692831a1df1d3f1b82b8b927b133 Mon Sep 17 00:00:00 2001 From: JMPRIEST Date: Tue, 23 May 2023 15:02:01 +0800 Subject: [PATCH 13/20] =?UTF-8?q?:construction:=20=E5=AE=8C=E6=88=90namesp?= =?UTF-8?q?ace=E5=9C=A8compile=E6=9C=9F=E9=97=B4=E7=9A=84=E8=B5=B7?= =?UTF-8?q?=E6=95=88=E9=80=BB=E8=BE=91=EF=BC=8Cddl&select&update&delete?= =?UTF-8?q?=E6=B5=8B=E8=AF=95=E9=80=9A=E8=BF=87=EF=BC=8C=E5=BE=85=E8=A1=A5?= 
=?UTF-8?q?namespace=E5=88=87=E6=8D=A2=E9=80=BB=E8=BE=91?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- edb/graphql/extension.pyx | 2 +- edb/pgsql/codegen.py | 9 +- edb/pgsql/common.py | 17 ++- edb/pgsql/compiler/__init__.py | 7 +- edb/pgsql/compiler/astutils.py | 4 +- edb/pgsql/compiler/config.py | 6 +- edb/pgsql/compiler/dml.py | 2 +- edb/pgsql/compiler/expr.py | 2 +- edb/pgsql/compiler/output.py | 2 +- edb/pgsql/compiler/relctx.py | 4 +- edb/pgsql/dbops/base.py | 18 ++- edb/pgsql/dbops/ddl.py | 15 +-- edb/pgsql/dbops/roles.py | 3 +- edb/pgsql/delta.py | 192 +++++++++++++---------------- edb/schema/delta.py | 9 +- edb/server/bootstrap.py | 6 +- edb/server/compiler/compiler.py | 29 ++--- edb/server/dbview/dbview.pxd | 4 +- edb/server/dbview/dbview.pyx | 14 +-- edb/server/pgcon/pgcon.pyx | 3 +- edb/server/protocol/binary.pxd | 1 + edb/server/protocol/binary.pyx | 34 ++--- edb/server/protocol/binary_v0.pyx | 25 ++-- edb/server/protocol/edgeql_ext.pyx | 8 +- edb/server/protocol/execute.pyx | 2 +- edb/server/protocol/schema_info.py | 2 +- edb/server/server.py | 52 ++++---- 27 files changed, 219 insertions(+), 253 deletions(-) diff --git a/edb/graphql/extension.pyx b/edb/graphql/extension.pyx index 017fe4c5b30..c1e3f2ce81e 100644 --- a/edb/graphql/extension.pyx +++ b/edb/graphql/extension.pyx @@ -262,7 +262,7 @@ async def _execute( if namespace not in db.ns_map: raise errors.InternalServerError( - f'NameSpace: [{namespace}] not in current db(ver:{db.dbver})' + f'NameSpace: [{namespace}] not in current db [{db.name}](ver:{db.dbver})' ) ns = db.ns_map[namespace] diff --git a/edb/pgsql/codegen.py b/edb/pgsql/codegen.py index abfc80bb8d5..de9ed46073f 100644 --- a/edb/pgsql/codegen.py +++ b/edb/pgsql/codegen.py @@ -63,11 +63,10 @@ def __init__(self, msg, *, node=None, details=None, hint=None): class SQLSourceGenerator(codegen.SourceGenerator): - def __init__(self, *args, reordered: bool=False, namespace: str = defines.DEFAULT_NS, **kwargs): + def __init__(self, *args, reordered: bool=False, **kwargs): super().__init__(*args, **kwargs) self.param_index: dict[object, int] = {} self.reordered = reordered - self.namespace = namespace @classmethod def to_source( @@ -120,10 +119,8 @@ def visit_Relation(self, node): if node.schemaname is None: self.write(common.qname(node.name)) else: - if self.namespace == defines.DEFAULT_NS: - self.write(common.qname(node.schemaname, node.name)) - elif node.schemaname in defines.EDGEDB_OWNED_DBS: - self.write(common.qname(f"{self.namespace}_{node.schemaname}", node.name)) + if node.schemaname in defines.EDGEDB_OWNED_DBS: + self.write(common.qname(common.actual_schemaname(node.schemaname), node.name)) else: self.write(common.qname(node.schemaname, node.name)) diff --git a/edb/pgsql/common.py b/edb/pgsql/common.py index f63aa93bec5..36c31637871 100644 --- a/edb/pgsql/common.py +++ b/edb/pgsql/common.py @@ -26,7 +26,7 @@ import re from edb.common import uuidgen -from edb.schema import abc as s_abc +from edb.schema import abc as s_abc, defines from edb.schema import casts as s_casts from edb.schema import constraints as s_constr from edb.schema import defines as s_def @@ -45,6 +45,7 @@ RE_LINK_TRIGGER = re.compile(r'(source|target)-del-(def|imm)-(inl|otl)-(f|t)') RE_DUNDER_TYPE_LINK_TRIGGER = re.compile(r'dunder-type-link-[ft]') +NAMESPACE = defines.DEFAULT_NS def quote_e_literal(string): @@ -129,7 +130,19 @@ def quote_type(type_): def get_module_backend_name(module: s_name.Name) -> str: # standard modules go into "edgedbstd", 
user ones into "edgedbpub" - return "edgedbstd" if module in s_schema.STD_MODULES else "edgedbpub" + return actual_schemaname("edgedbstd") if module in s_schema.STD_MODULES else actual_schemaname("edgedbpub") + + +def actual_schemaname(name: str) -> str: + global NAMESPACE + if name not in defines.EDGEDB_OWNED_DBS: + return name + + if NAMESPACE == defines.DEFAULT_NS: + ns_prefix = '' + else: + ns_prefix = NAMESPACE + '_' + return f"{ns_prefix}{name}" def get_unique_random_name() -> str: diff --git a/edb/pgsql/compiler/__init__.py b/edb/pgsql/compiler/__init__.py index d54eb109351..c5d479ed86b 100644 --- a/edb/pgsql/compiler/__init__.py +++ b/edb/pgsql/compiler/__init__.py @@ -27,7 +27,6 @@ from edb.common import exceptions as edgedb_error from edb.ir import ast as irast -from edb.schema import defines from edb.pgsql import ast as pgast from edb.pgsql import codegen as pgcodegen @@ -148,7 +147,6 @@ def compile_ir_to_sql( expected_cardinality_one: bool=False, pretty: bool=True, backend_runtime_params: Optional[pgparams.BackendRuntimeParams]=None, - namespace: str = defines.DEFAULT_NS ) -> Tuple[str, Dict[str, pgast.Param]]: qtree = compile_ir_to_sql_tree( @@ -174,7 +172,7 @@ def compile_ir_to_sql( argmap = {} # Generate query text - sql_text = run_codegen(qtree, pretty=pretty, namespace=namespace) + sql_text = run_codegen(qtree, pretty=pretty) if ( # pragma: no cover debug.flags.edgeql_compile or debug.flags.edgeql_compile_sql_text @@ -196,9 +194,8 @@ def run_codegen( *, pretty: bool=True, reordered: bool=False, - namespace: str = defines.DEFAULT_NS ) -> str: - codegen = pgcodegen.SQLSourceGenerator(pretty=pretty, reordered=reordered, namespace=namespace) + codegen = pgcodegen.SQLSourceGenerator(pretty=pretty, reordered=reordered) try: codegen.visit(qtree) except pgcodegen.SQLSourceGeneratorError as e: # pragma: no cover diff --git a/edb/pgsql/compiler/astutils.py b/edb/pgsql/compiler/astutils.py index 15af6760b91..8aa9db5b326 100644 --- a/edb/pgsql/compiler/astutils.py +++ b/edb/pgsql/compiler/astutils.py @@ -26,7 +26,7 @@ from edb.ir import typeutils as irtyputils -from edb.pgsql import ast as pgast +from edb.pgsql import ast as pgast, common from edb.pgsql import types as pg_types if TYPE_CHECKING: @@ -234,7 +234,7 @@ def safe_array_expr( ) if any(el.nullable for el in elements): result = pgast.FuncCall( - name=('edgedb', '_nullif_array_nulls'), + name=(common.actual_schemaname('edgedb'), '_nullif_array_nulls'), args=[result], ser_safe=ser_safe, ) diff --git a/edb/pgsql/compiler/config.py b/edb/pgsql/compiler/config.py index 98a6fe72eb6..03a8d086d71 100644 --- a/edb/pgsql/compiler/config.py +++ b/edb/pgsql/compiler/config.py @@ -66,7 +66,7 @@ def compile_ConfigSet( ) fcall = pgast.FuncCall( - name=('edgedb', '_alter_current_database_set'), + name=(common.actual_schemaname('edgedb'), '_alter_current_database_set'), args=[pgast.StringConstant(val=op.backend_setting), val], ) @@ -257,7 +257,7 @@ def compile_ConfigReset( elif op.scope is qltypes.ConfigScope.DATABASE and op.backend_setting: fcall = pgast.FuncCall( - name=('edgedb', '_alter_current_database_set'), + name=(common.actual_schemaname('edgedb'), '_alter_current_database_set'), args=[ pgast.StringConstant(val=op.backend_setting), pgast.NullConstant(), @@ -412,7 +412,7 @@ def _rewrite_config_insert( overwrite_query = pgast.SelectStmt() id_expr = pgast.FuncCall( - name=('edgedbext', 'uuid_generate_v1mc',), + name=(common.actual_schemaname('edgedbext'), 'uuid_generate_v1mc',), args=[], ) pathctx.put_path_identity_var( diff --git 
a/edb/pgsql/compiler/dml.py b/edb/pgsql/compiler/dml.py index 981ab348155..f35ad60c250 100644 --- a/edb/pgsql/compiler/dml.py +++ b/edb/pgsql/compiler/dml.py @@ -1564,7 +1564,7 @@ def check_update_type( # also the (dynamic) type of the argument, so that we can produce # a good error message. check_result = pgast.FuncCall( - name=('edgedb', 'issubclass'), + name=(common.actual_schemaname('edgedb'), 'issubclass'), args=[typ, typeref_val], ) maybe_null = pgast.CaseExpr( diff --git a/edb/pgsql/compiler/expr.py b/edb/pgsql/compiler/expr.py index 3314443ea7c..65c96f22636 100644 --- a/edb/pgsql/compiler/expr.py +++ b/edb/pgsql/compiler/expr.py @@ -560,7 +560,7 @@ def compile_TypeCheckOp( right = dispatch.compile(expr.right, ctx=newctx) result = pgast.FuncCall( - name=('edgedb', 'issubclass'), + name=(common.actual_schemaname('edgedb'), 'issubclass'), args=[left, right]) if negated: diff --git a/edb/pgsql/compiler/output.py b/edb/pgsql/compiler/output.py index 37b5d603df3..e607bc2be8e 100644 --- a/edb/pgsql/compiler/output.py +++ b/edb/pgsql/compiler/output.py @@ -560,7 +560,7 @@ def serialize_expr_to_json( elif irtyputils.is_range(styperef) and not expr.ser_safe: val = pgast.FuncCall( # Use the actual generic helper for converting anyrange to jsonb - name=('edgedb', 'range_to_jsonb'), + name=(common.actual_schemaname('edgedb'), 'range_to_jsonb'), args=[expr], null_safe=True, ser_safe=True) elif irtyputils.is_collection(styperef) and not expr.ser_safe: diff --git a/edb/pgsql/compiler/relctx.py b/edb/pgsql/compiler/relctx.py index b01e3c9b289..552e5325322 100644 --- a/edb/pgsql/compiler/relctx.py +++ b/edb/pgsql/compiler/relctx.py @@ -441,7 +441,7 @@ def new_free_object_rvar( qry = subctx.rel id_expr = pgast.FuncCall( - name=('edgedbext', 'uuid_generate_v4',), args=[] + name=(common.actual_schemaname('edgedbext'), 'uuid_generate_v4',), args=[] ) pathctx.put_path_identity_var(qry, path_id, id_expr, env=ctx.env) @@ -824,7 +824,7 @@ def ensure_transient_identity_for_path( ) -> None: id_expr = pgast.FuncCall( - name=('edgedbext', 'uuid_generate_v4',), + name=(common.actual_schemaname('edgedbext'), 'uuid_generate_v4',), args=[], ) diff --git a/edb/pgsql/dbops/base.py b/edb/pgsql/dbops/base.py index 50bcd2e4928..41047ede2ee 100644 --- a/edb/pgsql/dbops/base.py +++ b/edb/pgsql/dbops/base.py @@ -72,13 +72,12 @@ class PLExpression(str): class SQLBlock: - def __init__(self, namespace_prefix: str = ''): + def __init__(self): self.commands = [] self._transactional = True - self.namespace_prefix = namespace_prefix def add_block(self): - block = PLTopBlock(self.namespace_prefix) + block = PLTopBlock() self.add_command(block) return block @@ -113,11 +112,8 @@ def is_transactional(self) -> bool: class PLBlock(SQLBlock): - def __init__(self, top_block, level, namespace_prefix: str = ''): - if top_block is None: - super().__init__(namespace_prefix) - else: - super().__init__(top_block.namespace_prefix) + def __init__(self, top_block, level): + super().__init__() self.top_block = top_block self.varcounter = collections.defaultdict(int) self.shared_vars = set() @@ -136,7 +132,7 @@ def get_top_block(self) -> PLTopBlock: return self.top_block def add_block(self): - block = PLBlock(top_block=self.top_block, level=self.level + 1, namespace_prefix=self.namespace_prefix) + block = PLBlock(top_block=self.top_block, level=self.level + 1) self.add_command(block) return block @@ -248,8 +244,8 @@ def declare_var( class PLTopBlock(PLBlock): - def __init__(self, namespace_prefix: str = ''): - super().__init__(top_block=None, 
level=0, namespace_prefix=namespace_prefix) + def __init__(self): + super().__init__(top_block=None, level=0) self.declare_var('text', var_name='_dummy_text', shared=True) def add_block(self): diff --git a/edb/pgsql/dbops/ddl.py b/edb/pgsql/dbops/ddl.py index cd27142d91b..5d092e5c1ec 100644 --- a/edb/pgsql/dbops/ddl.py +++ b/edb/pgsql/dbops/ddl.py @@ -27,6 +27,7 @@ from ..common import quote_ident as qi from ..common import quote_literal as ql +from ..common import actual_schemaname as actual from . import base @@ -115,7 +116,7 @@ def code(self, block: base.PLBlock) -> str: if is_shared: return textwrap.dedent(f'''\ SELECT - {block.namespace_prefix}edgedb.shobj_metadata( + {actual("edgedb")}.shobj_metadata( {objoid}, {classoid}::regclass::text ) @@ -123,7 +124,7 @@ def code(self, block: base.PLBlock) -> str: elif objsubid: return textwrap.dedent(f'''\ SELECT - {block.namespace_prefix}edgedb.col_metadata( + {actual("edgedb")}.col_metadata( {objoid}, {objsubid} ) @@ -131,7 +132,7 @@ def code(self, block: base.PLBlock) -> str: else: return textwrap.dedent(f'''\ SELECT - {block.namespace_prefix}edgedb.obj_metadata( + {actual("edgedb")}.obj_metadata( {objoid}, {classoid}::regclass::text, ) @@ -149,7 +150,7 @@ def code(self, block: base.PLBlock) -> str: SELECT json FROM - {block.namespace_prefix}edgedbinstdata.instdata + {actual("edgedbinstdata")}.instdata WHERE key = {ql(key)} ''') @@ -211,7 +212,7 @@ def code(self, block: base.PLBlock) -> str: metadata = ql(json.dumps(self.metadata)) return textwrap.dedent(f'''\ UPDATE - {block.namespace_prefix}edgedbinstdata.instdata + {actual("edgedbinstdata")}.instdata SET json = {metadata} WHERE @@ -260,7 +261,7 @@ def code(self, block: base.PLBlock) -> str: return textwrap.dedent(f'''\ UPDATE - {block.namespace_prefix}edgedbinstdata.instdata + {actual("edgedbinstdata")}.instdata SET json = {json_v} || {meta_v} WHERE @@ -329,7 +330,7 @@ def code(self, block: base.PLBlock) -> str: json_v, meta_v = self._merge(block) return textwrap.dedent(f'''\ UPDATE - {block.namespace_prefix}edgedbinstdata.instdata + {actual("edgedbinstdata")}.instdata SET json = {json_v} || {meta_v} WHERE diff --git a/edb/pgsql/dbops/roles.py b/edb/pgsql/dbops/roles.py index ae7fa6c70d2..555a45a1994 100644 --- a/edb/pgsql/dbops/roles.py +++ b/edb/pgsql/dbops/roles.py @@ -25,6 +25,7 @@ from ..common import quote_ident as qi from ..common import quote_literal as ql +from ..common import actual_schemaname as actual from . import base from . 
import ddl @@ -153,7 +154,7 @@ def generate_extra(self, block: base.PLBlock) -> None: value = json.dumps(self.object.single_role_metadata) query = base.Query( f''' - UPDATE {block.namespace_prefix}edgedbinstdata.instdata + UPDATE {actual("edgedbinstdata")}.instdata SET json = {ql(value)}::jsonb WHERE key = 'single_role_metadata' ''' diff --git a/edb/pgsql/delta.py b/edb/pgsql/delta.py index 296053b8273..aa0b5dd54d3 100644 --- a/edb/pgsql/delta.py +++ b/edb/pgsql/delta.py @@ -34,7 +34,7 @@ from edb.edgeql import qltypes as ql_ft from edb.edgeql import compiler as qlcompiler -from edb.schema import annos as s_anno, defines +from edb.schema import annos as s_anno from edb.schema import casts as s_casts from edb.schema import scalars as s_scalars from edb.schema import objtypes as s_objtypes @@ -362,7 +362,7 @@ def apply( (SELECT version::text FROM - {context.pg_schema('edgedb')}."_SchemaSchemaVersion" + {common.actual_schemaname('edgedb')}."_SchemaSchemaVersion" FOR UPDATE), {ql(str(expected_ver))} )), @@ -370,7 +370,7 @@ def apply( msg => ( 'Cannot serialize DDL: ' || (SELECT version::text FROM - {context.pg_schema('edgedb')}."_SchemaSchemaVersion") + {common.actual_schemaname('edgedb')}."_SchemaSchemaVersion") ) ) INTO _dummy_text @@ -467,7 +467,7 @@ def apply( SELECT json FROM - {context.pg_schema('edgedbinstdata')}.instdata + edgedbinstdata.instdata WHERE key = {ql(key)} FOR UPDATE @@ -520,7 +520,7 @@ def apply( msg => ( 'Cannot serialize global DDL: ' || (SELECT version::text FROM - {context.pg_schema('edgedb')}."_SysGlobalSchemaVersion") + edgedb."_SysGlobalSchemaVersion") ) ) INTO _dummy_text @@ -537,7 +537,7 @@ def apply( (SELECT version::text FROM - {context.pg_schema('edgedb')}."_SysGlobalSchemaVersion" + edgedb."_SysGlobalSchemaVersion" ), {ql(str(expected_ver))} )), @@ -545,7 +545,7 @@ def apply( msg => ( 'Cannot serialize global DDL: ' || (SELECT version::text FROM - {context.pg_schema('edgedb')}."_SysGlobalSchemaVersion") + edgedb."_SysGlobalSchemaVersion") ) ) INTO _dummy_text @@ -1171,7 +1171,7 @@ def compile_edgeql_overloaded_function_body( target AS ancestor, index FROM - {context.pg_schema('edgedb')}."_SchemaObjectType__ancestors" + {common.actual_schemaname('edgedb')}."_SchemaObjectType__ancestors" WHERE source = {qi(type_param_name)} ) a WHERE ancestor IN ({impl_ids}) @@ -3541,7 +3541,7 @@ def apply( self.apply_scheduled_inhview_updates(schema, context) if is_external: - schema_name = context.pg_schema('edgedbpub') + schema_name = common.actual_schemaname('edgedbpub') view_name = (schema_name, str(objtype.id)) view_name_t = (schema_name, str(objtype.id) + '_t') self.pgops.add( @@ -5161,7 +5161,7 @@ def apply( self.apply_scheduled_inhview_updates(schema, context) if has_extern_table: - schema_name = context.pg_schema('edgedbpub') + schema_name = common.actual_schemaname('edgedbpub') view_name = (schema_name, str(link.id)) view_name_t = (schema_name, str(link.id) + '_t') self.pgops.add( @@ -5801,58 +5801,57 @@ def get_trigger_proc_name(self, schema, target, return common.get_backend_name( schema, target, catenate=False, aspect=aspect) - def get_trigger_proc_text( - self, target, links, *, disposition, inline, schema, namespace - ): + def get_trigger_proc_text(self, target, links, *, + disposition, inline, schema): if inline: return self._get_inline_link_trigger_proc_text( - target, links, disposition=disposition, schema=schema, namespace=namespace - ) + target, links, disposition=disposition, schema=schema) else: return self._get_outline_link_trigger_proc_text( - target, 
links, disposition=disposition, schema=schema, namespace=namespace + target, links, disposition=disposition, schema=schema) + + def _get_dunder_type_trigger_proc_text(self, target, *, schema): + body = textwrap.dedent( + '''SELECT + CASE WHEN tp.builtin + THEN '{std}' + ELSE '{pub}' + END AS sname + INTO schema_name + FROM {edb}."_SchemaType" as tp + WHERE tp.id = OLD.id; + + SELECT EXISTS ( + SELECT FROM pg_tables + WHERE schemaname = "schema_name" + AND tablename = OLD.id::text + ) INTO table_exists; + + IF table_exists THEN + target_sql = format('SELECT EXISTS (SELECT FROM %I.%I LIMIT 1)', "schema_name", OLD.id::text); + EXECUTE target_sql into del_prohibited; + ELSE + del_prohibited = FALSE; + END IF; + + IF del_prohibited THEN + RAISE foreign_key_violation + USING + TABLE = TG_TABLE_NAME, + SCHEMA = TG_TABLE_SCHEMA, + MESSAGE = 'deletion of {tgtname} (' || OLD.id + || ') is prohibited by link target policy', + DETAIL = 'Object is still referenced in link __type__' + || ' of ' || {edb}._get_schema_object_name(OLD.id) || ' (' + || OLD.id || ').'; + END IF; + '''.format( + tgtname=target.get_displayname(schema), + std=common.actual_schemaname('edgedbstd'), + pub=common.actual_schemaname('edgedbpub'), + edb=common.actual_schemaname('edgedb') ) - - def _get_dunder_type_trigger_proc_text(self, target, *, schema, namespace): - if namespace == defines.DEFAULT_NS: - ns_prefix = '' - else: - ns_prefix = namespace + '_' - body = textwrap.dedent('''\ - SELECT - CASE WHEN tp.builtin - THEN '{ns_prefix}edgedbstd' - ELSE '{ns_prefix}edgedbpub' - END AS sname - INTO schema_name - FROM {ns_prefix}edgedb."_SchemaType" as tp - WHERE tp.id = OLD.id; - - SELECT EXISTS ( - SELECT FROM pg_tables - WHERE schemaname = "schema_name" - AND tablename = OLD.id::text - ) INTO table_exists; - - IF table_exists THEN - target_sql = format('SELECT EXISTS (SELECT FROM %I.%I LIMIT 1)', "schema_name", OLD.id::text); - EXECUTE target_sql into del_prohibited; - ELSE - del_prohibited = FALSE; - END IF; - - IF del_prohibited THEN - RAISE foreign_key_violation - USING - TABLE = TG_TABLE_NAME, - SCHEMA = TG_TABLE_SCHEMA, - MESSAGE = 'deletion of {tgtname} (' || OLD.id - || ') is prohibited by link target policy', - DETAIL = 'Object is still referenced in link __type__' - || ' of ' || {ns_prefix}edgedb._get_schema_object_name(OLD.id) || ' (' - || OLD.id || ').'; - END IF; - '''.format(tgtname=target.get_displayname(schema), ns_prefix=ns_prefix)) + ) text = textwrap.dedent('''\ DECLARE @@ -5868,12 +5867,7 @@ def _get_dunder_type_trigger_proc_text(self, target, *, schema, namespace): return text def _get_outline_link_trigger_proc_text( - self, target, links, *, disposition, schema, namespace - ): - if namespace == defines.DEFAULT_NS: - ns_prefix = '' - else: - ns_prefix = namespace + '_' + self, target, links, *, disposition, schema): chunks = [] @@ -5932,11 +5926,11 @@ def _declare_var(var_prefix, index, var_type): IF FOUND THEN SELECT - {ns_prefix}edgedb.shortname_from_fullname(link.name), - {ns_prefix}edgedb._get_schema_object_name(link.{far_endpoint}) + {edb}.shortname_from_fullname(link.name), + {edb}._get_schema_object_name(link.{far_endpoint}) INTO linkname, endname FROM - {ns_prefix}edgedb."_SchemaLink" AS link + {edb}."_SchemaLink" AS link WHERE link.id = link_type_id; RAISE foreign_key_violation @@ -5957,7 +5951,7 @@ def _declare_var(var_prefix, index, var_type): tgtname=target.get_displayname(schema), near_endpoint=near_endpoint, far_endpoint=far_endpoint, - ns_prefix=ns_prefix + edb=common.actual_schemaname('edgedb') 
) chunks.append(text) @@ -6200,12 +6194,7 @@ def _resolve_link_group( return tuple(group_var) def _get_inline_link_trigger_proc_text( - self, target, links, *, disposition, schema, namespace - ): - if namespace == defines.DEFAULT_NS: - ns_prefix = '' - else: - ns_prefix = namespace + '_' + self, target, links, *, disposition, schema): chunks = [] @@ -6258,11 +6247,11 @@ def _get_inline_link_trigger_proc_text( IF FOUND THEN SELECT - {ns_prefix}edgedb.shortname_from_fullname(link.name), - {ns_prefix}edgedb._get_schema_object_name(link.{far_endpoint}) + {edb}.shortname_from_fullname(link.name), + {edb}._get_schema_object_name(link.{far_endpoint}) INTO linkname, endname FROM - {ns_prefix}edgedb."_SchemaLink" AS link + {edb}."_SchemaLink" AS link WHERE link.id = link_type_id; RAISE foreign_key_violation @@ -6283,7 +6272,7 @@ def _get_inline_link_trigger_proc_text( tgtname=target.get_displayname(schema), near_endpoint=near_endpoint, far_endpoint=far_endpoint, - ns_prefix=ns_prefix + edb=common.actual_schemaname('edgedb') ) chunks.append(text) @@ -6549,14 +6538,12 @@ def apply( if links or modifications: self._update_action_triggers( - schema, source, links, disposition='source', namespace=context.namespace - ) + schema, source, links, disposition='source') if inline_links or modifications: self._update_action_triggers( schema, source, inline_links, - inline=True, disposition='source', namespace=context.namespace - ) + inline=True, disposition='source') # All descendants of affected targets also need to have their # triggers updated, so track them down. @@ -6643,31 +6630,27 @@ def apply( key=lambda l: l.get_name(schema)) if dunder_type_links: - self._update_dunder_type_link_triggers(schema, target, context.namespace) + self._update_dunder_type_link_triggers(schema, target) if links or modifications: self._update_action_triggers( - schema, target, links, disposition='target', namespace=context.namespace - ) + schema, target, links, disposition='target') if inline_links or modifications: self._update_action_triggers( schema, target, inline_links, - disposition='target', inline=True, namespace=context.namespace - ) + disposition='target', inline=True) if deferred_links or modifications: self._update_action_triggers( schema, target, deferred_links, - disposition='target', deferred=True, namespace=context.namespace - ) + disposition='target', deferred=True) if deferred_inline_links or modifications: self._update_action_triggers( schema, target, deferred_inline_links, disposition='target', deferred=True, - inline=True, namespace=context.namespace - ) + inline=True) return schema @@ -6675,7 +6658,6 @@ def _update_dunder_type_link_triggers( self, schema, objtype: s_objtypes.ObjectType, - namespace: str ) -> None: table_name = common.get_backend_name( schema, objtype, catenate=False) @@ -6692,8 +6674,7 @@ def _update_dunder_type_link_triggers( is_constraint=True, inherit=True, deferred=False) proc_text = self._get_dunder_type_trigger_proc_text( - objtype, schema=schema, namespace=namespace - ) + objtype, schema=schema) trig_func = dbops.Function( name=proc_name, text=proc_text, volatility='volatile', @@ -6708,15 +6689,13 @@ def _update_dunder_type_link_triggers( )) def _update_action_triggers( - self, - schema, - objtype: s_objtypes.ObjectType, - links: List[s_links.Link], *, - disposition: str, - namespace: str, - deferred: bool = False, - inline: bool = False, - ) -> None: + self, + schema, + objtype: s_objtypes.ObjectType, + links: List[s_links.Link], *, + disposition: str, + deferred: bool=False, + 
inline: bool=False) -> None: table_name = common.get_backend_name( schema, objtype, catenate=False) @@ -6737,8 +6716,7 @@ def _update_action_triggers( if links: proc_text = self.get_trigger_proc_text( objtype, links, disposition=disposition, - inline=inline, schema=schema, namespace=namespace - ) + inline=inline, schema=schema) trig_func = dbops.Function( name=proc_name, text=proc_text, volatility='volatile', @@ -6795,7 +6773,7 @@ def collect_external_objects( view_def = context.external_view[key] if context.restoring_external: - schema_name = context.pg_schema('edgedbpub') + schema_name = common.actual_schemaname('edgedbpub') self.external_views.append(dbops.View(query=view_def, name=(schema_name, str(obj.id)))) self.external_views.append(dbops.View(query=view_def, name=(schema_name, str(obj.id) + '_t'))) return @@ -6818,7 +6796,7 @@ def collect_external_objects( ptrname = ptr.get_shortname(schema).name if ptrname == 'id': - columns.append(f"{context.pg_schema('edgedbext')}.uuid_generate_v1mc() AS id") + columns.append(f"{common.actual_schemaname('edgedbext')}.uuid_generate_v1mc() AS id") elif ptrname == '__type__': columns.append(f"'{(str(obj.id))}'::uuid AS __type__") elif has_link_table: @@ -6847,7 +6825,7 @@ def collect_external_objects( if join_link_table is not None: query += f", (SELECT * FROM {join_link_table.relation}) AS INNER_T " \ f"where INNER_T.{join_link_table.columns['source']} = SOURCE_T.{source_identity}" - schema_name = context.pg_schema('edgedbpub') + schema_name = common.actual_schemaname('edgedbpub') self.external_views.append(dbops.View(query=query, name=(schema_name, str(obj.id)))) self.external_views.append(dbops.View(query=query, name=(schema_name, str(obj.id) + '_t'))) diff --git a/edb/schema/delta.py b/edb/schema/delta.py index bc675d23b86..1c7f035210f 100644 --- a/edb/schema/delta.py +++ b/edb/schema/delta.py @@ -43,7 +43,7 @@ from edb.edgeql import compiler as qlcompiler from edb.edgeql import qltypes -from . import expr as s_expr, defines +from . import expr as s_expr from . import name as sn from . import objects as so from . 
import schema as s_schema @@ -1219,7 +1219,6 @@ def __init__( module_is_implicit: Optional[bool] = False, external_view: Optional[Mapping] = None, restoring_external: Optional[bool] = False, - namespace: str = defines.DEFAULT_NS, ) -> None: self.stack: List[CommandContextToken[Command]] = [] self._cache: Dict[Hashable, Any] = {} @@ -1251,7 +1250,6 @@ def __init__( self.external_view = external_view or immutables.Map() self.restoring_external = restoring_external self.external_objs = set() - self.namespace = namespace @property def modaliases(self) -> Mapping[Optional[str], str]: @@ -1505,11 +1503,6 @@ def compat_ver_is_before( ) -> bool: return self.compat_ver is not None and self.compat_ver < ver - def pg_schema(self, schema_name: str): - if self.namespace == defines.DEFAULT_NS: - return schema_name - return f"{self.namespace}_{schema_name}" - class ContextStack: diff --git a/edb/server/bootstrap.py b/edb/server/bootstrap.py index 8c5524494d1..8914e1906ec 100644 --- a/edb/server/bootstrap.py +++ b/edb/server/bootstrap.py @@ -1018,7 +1018,7 @@ async def gen_tpl_dump(cluster: pgcluster.BaseCluster): exclude_schemas=['edgedbext'], dump_object_owners=False, ) - commands = [dbops.CreateSchema(name='{ns_edgedbext}')] + commands = [dbops.CreateSchema(name='{ns_prefix}edgedbext')] for uuid_func in [ 'uuid_generate_v1', 'uuid_generate_v1mc', @@ -1032,7 +1032,7 @@ async def gen_tpl_dump(cluster: pgcluster.BaseCluster): commands.append( dbops.CreateOrReplaceFunction( dbops.Function( - name=('{ns_edgedbext}', uuid_func), + name=('{ns_prefix}edgedbext', uuid_func), returns=('pg_catalog', 'uuid'), language='plpgsql', text=f""" BEGIN @@ -1047,7 +1047,7 @@ async def gen_tpl_dump(cluster: pgcluster.BaseCluster): commands.append( dbops.CreateOrReplaceFunction( dbops.Function( - name=('{ns_edgedbext}', uuid_func), + name=('{ns_prefix}edgedbext', uuid_func), returns=('pg_catalog', 'uuid'), language='plpgsql', args=[('namespace', 'uuid'), ('name', 'text')], text=f""" diff --git a/edb/server/compiler/compiler.py b/edb/server/compiler/compiler.py index 7c474bcdb0c..f65da65d53b 100644 --- a/edb/server/compiler/compiler.py +++ b/edb/server/compiler/compiler.py @@ -391,7 +391,6 @@ def _new_delta_context(self, ctx: CompileContext): context.module = ctx.module context.external_view = ctx.external_view context.restoring_external = ctx.restoring_external - context.namespace = ctx.namespace return context def _process_delta(self, ctx: CompileContext, delta): @@ -423,16 +422,12 @@ def _process_delta(self, ctx: CompileContext, delta): isinstance(c, s_ns.NameSpaceCommand) for c in pgdelta.get_subcommands() ) - if ctx.namespace == defines.DEFAULT_NS: - ns_prefix = '' - else: - ns_prefix = ctx.namespace + '_' if db_cmd or ns_cmd: block = pg_dbops.SQLBlock() new_be_types = new_types = frozenset() else: - block = pg_dbops.PLTopBlock(ns_prefix) + block = pg_dbops.PLTopBlock() def may_has_backend_id(_id): obj = schema.get_by_id(_id, None) @@ -468,7 +463,7 @@ def may_has_backend_id(_id): # Generate schema storage SQL (DML into schema storage tables). 
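# [Editor's note: illustrative sketch, not part of the patch.] Throughout this
# series the explicit ns_prefix plumbing is being replaced by calls to
# pg_common.actual_schemaname(), driven by the module-level pg_common.NAMESPACE
# that the compiler sets from ctx.namespace (see the _compile() hunk below).
# Assuming the default namespace keeps the unprefixed backend schema names, as
# the deleted DEFAULT_NS branches here show, the mapping amounts to:
def actual_schemaname(schema_name: str) -> str:
    # Default namespace: bare names ("edgedb", "edgedbstd", "edgedbpub", ...).
    if NAMESPACE == DEFAULT_NS:
        return schema_name
    # Any other namespace owns prefixed copies, e.g. "ns1_edgedb".
    return f"{NAMESPACE}_{schema_name}"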
if schema_peristence_async: - refl_block = pg_dbops.PLTopBlock(ns_prefix) + refl_block = pg_dbops.PLTopBlock() else: refl_block = None @@ -476,7 +471,7 @@ def may_has_backend_id(_id): ctx, pgdelta, subblock, context=context, schema_persist_block=refl_block ) - instdata_schemaname = f"{ns_prefix}edgedbinstdata" + instdata_schemaname = pg_common.actual_schemaname("edgedbinstdata") if schema_peristence_async: if debug.flags.keep_schema_persistence_history: invalid_persist_his = f"""\ @@ -500,7 +495,7 @@ def may_has_backend_id(_id): """)) if pgdelta.std_inhview_updates: - stdview_block = pg_dbops.PLTopBlock(ns_prefix) + stdview_block = pg_dbops.PLTopBlock() pgdelta.generate_std_inhview(stdview_block) else: stdview_block = None @@ -544,15 +539,10 @@ def _compile_schema_storage_in_delta( cache = current_tx.get_cached_reflection() - if ctx.namespace == defines.DEFAULT_NS: - schema_name = 'edgedb' - else: - schema_name = f'{ctx.namespace}_edgedb' - with cache.mutate() as cache_mm: for eql, args in meta_blocks: eql_hash = hashlib.sha1(eql.encode()).hexdigest() - fname = (schema_name, f'__rh_{eql_hash}') + fname = (pg_common.actual_schemaname('edgedb'), f'__rh_{eql_hash}') if eql_hash in cache_mm: argnames = cache_mm[eql_hash] @@ -787,7 +777,6 @@ def _compile_ql_query( expected_cardinality_one=ctx.expected_cardinality_one, output_format=_convert_format(ctx.output_format), backend_runtime_params=ctx.backend_runtime_params, - namespace=ctx.namespace ) if ( @@ -1100,11 +1089,6 @@ def _compile_and_apply_ddl_stmt( else: sql = (block.to_string().encode('utf-8'),) - if context.namespace == defines.DEFAULT_NS: - ns_prefix = '' - else: - ns_prefix = context.namespace + '_' - if new_types: # Inject a query returning backend OIDs for the newly # created types. @@ -1122,7 +1106,7 @@ def _compile_and_apply_ddl_stmt( "backend_id" ) FROM - {ns_prefix}edgedb."_SchemaType" + {pg_common.actual_schemaname('edgedb')}."_SchemaType" WHERE "id" = any(ARRAY[ {', '.join(new_type_ids)} @@ -1945,6 +1929,7 @@ def _compile( source: edgeql.Source, ) -> dbstate.QueryUnitGroup: current_tx = ctx.state.current_tx() + pg_common.NAMESPACE = ctx.namespace if current_tx.get_migration_state() is not None: original = edgeql.Source.from_string(source.text()) ctx = dataclasses.replace( diff --git a/edb/server/dbview/dbview.pxd b/edb/server/dbview/dbview.pxd index 44c15e14352..0a3e820fd43 100644 --- a/edb/server/dbview/dbview.pxd +++ b/edb/server/dbview/dbview.pxd @@ -224,8 +224,8 @@ cdef class DatabaseConnectionView: cpdef get_modaliases(self) cdef bytes serialize_state(self) - cdef bint is_state_desc_changed(self, namespace=?) + cdef bint is_state_desc_changed(self, namespace) cdef describe_state(self, namespace) cdef encode_state(self) - cdef decode_state(self, type_id, data, namespace=?) 
+ cdef decode_state(self, type_id, data, namespace) cdef inline recode_global(self, serializer, k, v) diff --git a/edb/server/dbview/dbview.pyx b/edb/server/dbview/dbview.pyx index bc7cf847523..a72ff7e592f 100644 --- a/edb/server/dbview/dbview.pyx +++ b/edb/server/dbview/dbview.pyx @@ -576,7 +576,7 @@ cdef class Database: cdef get_state_serializer(self, namespace, protocol_version): if namespace not in self.ns_map: raise errors.InternalServerError( - f'NameSpace: [{namespace}] not in current db(ver:{self._db.dbver})' + f'NameSpace: [{namespace}] not in current db [{self.name}](ver:{self.dbver})' ) return self.ns_map[namespace].get_state_serializer(protocol_version) @@ -819,14 +819,14 @@ cdef class DatabaseConnectionView: if namespace in self._db.ns_map: return self._db.ns_map[namespace].user_schema raise errors.InternalServerError( - f'NameSpace: [{namespace}] not in current db(ver:{self._db.dbver})' + f'NameSpace: [{namespace}] not in current db [{self._db.name}](ver:{self._db.dbver})' ) def get_reflection_cache(self, namespace: str): if namespace in self._db.ns_map: return self._db.ns_map[namespace].reflection_cache raise errors.InternalServerError( - f'NameSpace: [{namespace}] not in current db(ver:{self._db.dbver})' + f'NameSpace: [{namespace}] not in current db [{self._db.name}](ver:{self._db.dbver})' ) def get_global_schema(self): @@ -859,7 +859,7 @@ cdef class DatabaseConnectionView: if namespace not in self._db.ns_map: raise errors.InternalServerError( - f'NameSpace: [{namespace}] not in current db(ver:{self._db.dbver})' + f'NameSpace: [{namespace}] not in current db [{self._db.name}](ver:{self._db.dbver})' ) tid = self._db.ns_map[namespace].backend_ids.get(type_id) @@ -893,7 +893,7 @@ cdef class DatabaseConnectionView: self._session_state_db_cache = (self._config, spec) return spec - cdef bint is_state_desc_changed(self, namespace=defines.DEFAULT_NS): + cdef bint is_state_desc_changed(self, namespace): serializer = self.get_state_serializer(namespace) if not self._in_tx: # We may have executed a query, or COMMIT/ROLLBACK - just use @@ -964,7 +964,7 @@ cdef class DatabaseConnectionView: state['globals'] = {k: v.value for k, v in globals_.items()} return serializer.type_id, serializer.encode(state) - cdef decode_state(self, type_id, data, namespace=defines.DEFAULT_NS): + cdef decode_state(self, type_id, data, namespace): if not self._in_tx: # make sure we start clean self._state_serializer = None @@ -1067,7 +1067,7 @@ cdef class DatabaseConnectionView: if key.namespace not in self._db.ns_map: raise errors.InternalServerError( - f'NameSpace: [{key.namespace}] not in current db(ver:{self._db.dbver})' + f'NameSpace: [{key.namespace}] not in current db [{self._db.name}](ver:{self._db.dbver})' ) ns = self._db.ns_map[key.namespace] diff --git a/edb/server/pgcon/pgcon.pyx b/edb/server/pgcon/pgcon.pyx index 96bc6f12410..f9143d41525 100644 --- a/edb/server/pgcon/pgcon.pyx +++ b/edb/server/pgcon/pgcon.pyx @@ -1947,8 +1947,7 @@ cdef class PGConnection: elif event == 'system-config-changes': self.server._on_remote_system_config_change() elif event == 'global-schema-changes': - namespace = event_payload['namespace'] - self.server._on_global_schema_change(namespace) + self.server._on_global_schema_change() else: raise AssertionError(f'unexpected system event: {event!r}') diff --git a/edb/server/protocol/binary.pxd b/edb/server/protocol/binary.pxd index f6738b14f16..5a255d22d9c 100644 --- a/edb/server/protocol/binary.pxd +++ b/edb/server/protocol/binary.pxd @@ -64,6 +64,7 @@ cdef class 
EdgeConnection(frontend.FrontendConnection): object loop readonly dbview.DatabaseConnectionView _dbview str dbname + str namespace ReadBuffer buffer diff --git a/edb/server/protocol/binary.pyx b/edb/server/protocol/binary.pyx index 230ad1db1d7..a1c370f12e5 100644 --- a/edb/server/protocol/binary.pyx +++ b/edb/server/protocol/binary.pyx @@ -22,6 +22,7 @@ import collections import hashlib import json import logging +import os import time import statistics import traceback @@ -166,6 +167,7 @@ cdef class EdgeConnection(frontend.FrontendConnection): self.loop = server.get_loop() self._dbview = None self.dbname = None + self.namespace = None self._transport = None self.buffer = ReadBuffer() @@ -476,7 +478,11 @@ cdef class EdgeConnection(frontend.FrontendConnection): f'accept connections' ) - await self._start_connection(database) + # for local test + namespace = params.get('namespace', os.environ.get('EDGEDB_TEST_CASES_NAMESPACE', defines.DEFAULT_NS)) + # namespace = params.get('namespace', edbdef.DEFAULT_NS) + + await self._start_connection(database, namespace) # The user has already been authenticated by other means # (such as the ability to write to a protected socket). @@ -515,7 +521,7 @@ cdef class EdgeConnection(frontend.FrontendConnection): msg_buf.write_bytes(b'\x00' * 32) msg_buf.end_message() buf.write_buffer(msg_buf) - # TODO add namespace + buf.write_buffer(self.make_state_data_description_msg()) self.write(buf) @@ -587,7 +593,7 @@ cdef class EdgeConnection(frontend.FrontendConnection): return params - async def _start_connection(self, database: str) -> None: + async def _start_connection(self, database: str, namespace: str) -> None: dbv = await self.server.new_dbview( dbname=database, query_cache=self.query_cache_enabled, @@ -596,8 +602,10 @@ cdef class EdgeConnection(frontend.FrontendConnection): assert type(dbv) is dbview.DatabaseConnectionView self._dbview = dbv self.dbname = database + self.namespace = namespace self._con_status = EDGECON_STARTED + logger.info(f'Connection started to {database}[{namespace}].') def stop_connection(self) -> None: self._con_status = EDGECON_BAD @@ -1005,10 +1013,10 @@ cdef class EdgeConnection(frontend.FrontendConnection): msg.end_message() return msg - cdef WriteBuffer make_state_data_description_msg(self, namespace=edbdef.DEFAULT_NS): + cdef WriteBuffer make_state_data_description_msg(self, namespace=None): cdef WriteBuffer msg - type_id, type_data = self.get_dbview().describe_state(namespace) + type_id, type_data = self.get_dbview().describe_state(namespace or self.namespace) msg = WriteBuffer.new_message(b's') msg.write_bytes(type_id.bytes) @@ -1141,10 +1149,8 @@ cdef class EdgeConnection(frontend.FrontendConnection): state_tid = self.buffer.read_bytes(16) state_data = self.buffer.read_len_prefixed_bytes() try: - # TODO add namespace - self.get_dbview().decode_state(state_tid, state_data) + self.get_dbview().decode_state(state_tid, state_data, self.namespace) except errors.StateMismatchError: - # TODO add namespace self.write(self.make_state_data_description_msg()) raise @@ -1158,6 +1164,7 @@ cdef class EdgeConnection(frontend.FrontendConnection): inline_typenames=inline_typenames, inline_objectids=inline_objectids, allow_capabilities=allow_capabilities, + namespace=self.namespace ) async def parse(self): @@ -1277,8 +1284,7 @@ cdef class EdgeConnection(frontend.FrontendConnection): if self._cancelled: raise ConnectionAbortedError - # TODO add namespace - if _dbview.is_state_desc_changed(): + if _dbview.is_state_desc_changed(self.namespace): 
self.write(self.make_state_data_description_msg()) self.write( self.make_command_complete_msg( @@ -1580,7 +1586,7 @@ cdef class EdgeConnection(frontend.FrontendConnection): # only use the backend if schema is required if static_exc is errormech.SchemaRequired: exc = errormech.interpret_backend_error( - self.get_dbview().get_schema(), + self.get_dbview().get_schema(self.namespace), exc.fields ) elif isinstance(static_exc, ( @@ -2038,7 +2044,7 @@ cdef class EdgeConnection(frontend.FrontendConnection): compiler_pool = server.get_compiler_pool() # TODO add namespace global_schema = _dbview.get_global_schema() - user_schema = _dbview.get_user_schema() + user_schema = _dbview.get_user_schema(edbdef.DEFAULT_NS) dump_server_ver_str = None headers_num = self.buffer.read_int16() @@ -2100,7 +2106,7 @@ cdef class EdgeConnection(frontend.FrontendConnection): self._in_dump_restore = True try: - _dbview.decode_state(sertypes.NULL_TYPE_ID.bytes, b'') + _dbview.decode_state(sertypes.NULL_TYPE_ID.bytes, b'', edbdef.DEFAULT_NS) await self._execute_utility_stmt( 'START TRANSACTION ISOLATION SERIALIZABLE', pgcon, @@ -2243,7 +2249,7 @@ cdef class EdgeConnection(frontend.FrontendConnection): await server.introspect(dbname) # TODO add namespace - if _dbview.is_state_desc_changed(): + if _dbview.is_state_desc_changed(edbdef.DEFAULT_NS): self.write(self.make_state_data_description_msg()) state_tid, state_data = _dbview.encode_state() diff --git a/edb/server/protocol/binary_v0.pyx b/edb/server/protocol/binary_v0.pyx index e381d959da6..477724a7e20 100644 --- a/edb/server/protocol/binary_v0.pyx +++ b/edb/server/protocol/binary_v0.pyx @@ -197,7 +197,9 @@ cdef class EdgeConnectionBackwardsCompatible(EdgeConnection): f'accept connections' ) - await self._start_connection(database) + namespace = params.get('namespace', edbdef.DEFAULT_NS) + + await self._start_connection(database, namespace) # The user has already been authenticated by other means # (such as the ability to write to a protected socket). 
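# [Editor's note: illustrative sketch, not part of the patch.] After this
# change the namespace travels with each binary connection: it is read from
# the handshake params ('namespace', defaulting to DEFAULT_NS) and threaded
# through describe_state()/decode_state()/is_state_desc_changed(), while dump
# and restore still pin edbdef.DEFAULT_NS, as the remaining TODO comments note.
# A hypothetical client would therefore pick its namespace at connect time:
params = {
    'user': 'edgedb',
    'database': 'mydb',
    'namespace': 'ns1',  # consumed above via params.get('namespace', ...)
}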
@@ -520,7 +522,7 @@ cdef class EdgeConnectionBackwardsCompatible(EdgeConnection): self._in_dump_restore = True try: # TODO add namespace - _dbview.decode_state(sertypes.NULL_TYPE_ID.bytes, b'') + _dbview.decode_state(sertypes.NULL_TYPE_ID.bytes, b'', edbdef.DEFAULT_NS) await self._execute_utility_stmt( 'START TRANSACTION ISOLATION SERIALIZABLE', pgcon, @@ -935,7 +937,6 @@ cdef class EdgeConnectionBackwardsCompatible(EdgeConnection): bint inline_objectids = True bytes stmt_name = b'' str module = None - str namespace = None bint read_only = False headers = self.legacy_parse_headers() @@ -955,8 +956,6 @@ cdef class EdgeConnectionBackwardsCompatible(EdgeConnection): module = v.decode() elif k == QUERY_HEADER_PROHIBIT_MUTATION: read_only = parse_boolean(v, "PROHIBIT_MUTATION") - elif k == QUERY_HEADER_EXPLICIT_NS: - namespace = v.decode() else: raise errors.BinaryProtocolError( f'unexpected message header: {k}' @@ -991,7 +990,7 @@ cdef class EdgeConnectionBackwardsCompatible(EdgeConnection): allow_capabilities=allow_capabilities, module=module, read_only=read_only, - namespace=namespace + namespace=self.namespace ) return eql, query_req, stmt_name @@ -1084,7 +1083,6 @@ cdef class EdgeConnectionBackwardsCompatible(EdgeConnection): skip_first: bool, module: str = None, read_only: bool = False, - namespace: str = defines.DEFAULT_NS, ): query_req = dbview.QueryRequestInfo( source=edgeql.Source.from_string(query.decode("utf-8")), @@ -1092,7 +1090,7 @@ cdef class EdgeConnectionBackwardsCompatible(EdgeConnection): output_format=FMT_NONE, module=module, read_only=read_only, - namespace=namespace, + namespace=self.namespace, ) return await self.get_dbview()._compile( @@ -1145,7 +1143,6 @@ cdef class EdgeConnectionBackwardsCompatible(EdgeConnection): module = None read_only = False - namespace = defines.DEFAULT_NS headers = self.legacy_parse_headers() if headers: for k, v in headers.items(): @@ -1153,8 +1150,6 @@ cdef class EdgeConnectionBackwardsCompatible(EdgeConnection): allow_capabilities = parse_capabilities_header(v) elif k == QUERY_HEADER_EXPLICIT_MODULE: module = v.decode() - elif k == QUERY_HEADER_EXPLICIT_NS: - namespace = v.decode() elif k == QUERY_HEADER_PROHIBIT_MUTATION: read_only = parse_boolean(v, "PROHIBIT_MUTATION") else: @@ -1192,7 +1187,7 @@ cdef class EdgeConnectionBackwardsCompatible(EdgeConnection): query_unit = await self._legacy_simple_query( eql, allow_capabilities, skip_first, - module, read_only, namespace) + module, read_only) packet = WriteBuffer.new() packet.write_buffer(self.make_legacy_command_complete_msg(query_unit)) @@ -1207,7 +1202,6 @@ cdef class EdgeConnectionBackwardsCompatible(EdgeConnection): skip_first: bint, module: str = None, read_only: bool = False, - namespace: str = defines.DEFAULT_NS, ): cdef: bytes state = None, orig_state = None @@ -1216,8 +1210,7 @@ cdef class EdgeConnectionBackwardsCompatible(EdgeConnection): pgcon.PGConnection conn unit_group = await self._legacy_compile_script( - eql, skip_first=skip_first, module=module, read_only=read_only, - namespace=namespace + eql, skip_first=skip_first, module=module, read_only=read_only ) if self._cancelled: @@ -1255,7 +1248,7 @@ cdef class EdgeConnectionBackwardsCompatible(EdgeConnection): if query_unit.create_ns: await self.server.create_namespace(conn, query_unit.create_ns) if query_unit.drop_ns: - await self.server._on_before_drop_ns(query_unit.drop_ns, _dbview.namespace) + await self.server._on_before_drop_ns(query_unit.drop_ns) if query_unit.system_config: await execute.execute_system_config(conn, 
_dbview, query_unit) else: diff --git a/edb/server/protocol/edgeql_ext.pyx b/edb/server/protocol/edgeql_ext.pyx index ea8b5086f5b..5fac18076fd 100644 --- a/edb/server/protocol/edgeql_ext.pyx +++ b/edb/server/protocol/edgeql_ext.pyx @@ -26,7 +26,7 @@ import immutables from edb import errors from edb import edgeql -from edb.server import defines as edbdef, defines +from edb.server import defines as edbdef from edb.server.protocol import execute from edb.common import debug @@ -64,7 +64,7 @@ async def handle_request( query = None module = None limit = 0 - namespace = defines.DEFAULT_NS + namespace = edbdef.DEFAULT_NS try: if request.method == b'POST': @@ -77,7 +77,7 @@ async def handle_request( variables = body.get('variables') globals_ = body.get('globals') module = body.get('module') - namespace = body.get('namespace', defines.DEFAULT_NS) + namespace = body.get('namespace', edbdef.DEFAULT_NS) limit = body.get('limit', 0) else: raise TypeError( @@ -116,7 +116,7 @@ async def handle_request( if namespace is not None: namespace = namespace[0] else: - namespace = defines.DEFAULT_NS + namespace = edbdef.DEFAULT_NS limit = qs.get('limit') if limit is not None: diff --git a/edb/server/protocol/execute.pyx b/edb/server/protocol/execute.pyx index 4df2ce6692c..91e9d6f12f6 100644 --- a/edb/server/protocol/execute.pyx +++ b/edb/server/protocol/execute.pyx @@ -432,7 +432,7 @@ async def parse_execute( dbv = await server.new_dbview( dbname=db.name, query_cache=False, - protocol_version=edbdef.CURRENT_PROTOCOL + protocol_version=edbdef.CURRENT_PROTOCOL, ) query_req = dbview.QueryRequestInfo( diff --git a/edb/server/protocol/schema_info.py b/edb/server/protocol/schema_info.py index 8b0fe22d76e..dfcc0572a38 100644 --- a/edb/server/protocol/schema_info.py +++ b/edb/server/protocol/schema_info.py @@ -112,7 +112,7 @@ async def handle_request( async def execute(db, server, namespace: str, query_uuid: str): if namespace not in db.ns_map: raise errors.InternalServerError( - f'NameSpace: [{namespace}] not in current db(ver:{db.dbver})' + f'NameSpace: [{namespace}] not in current db [{db.name}](ver:{db.dbver})' ) user_schema = db.ns_map[namespace].user_schema global_schema = server.get_global_schema() diff --git a/edb/server/server.py b/edb/server/server.py index 45eabba974f..5364bb9c36f 100644 --- a/edb/server/server.py +++ b/edb/server/server.py @@ -78,7 +78,17 @@ ADMIN_PLACEHOLDER = "" logger = logging.getLogger('edb.server') log_metrics = logging.getLogger('edb.server.metrics') -_RE_BYTES_REPL_NS = re.compile(rb'(edgedb)(\.|instdata|pub|ss|std|;)', flags=re.MULTILINE) +_RE_BYTES_REPL_NS = re.compile( + r'(current_setting\([\']+)?' 
+ r'(edgedb)(\.|instdata|pub|ss|std|;)', ) + + +def repl_ignore_setting(match_obj): + maybe_setting, to_repl, tailing = match_obj.groups() + if maybe_setting: + return maybe_setting + to_repl + tailing + return "{ns_prefix}" + to_repl + tailing class RoleDescriptor(TypedDict): @@ -98,10 +108,11 @@ class Server(ha_base.ClusterProtocol): _roles: Mapping[str, RoleDescriptor] _instance_data: Mapping[str, str] _sys_queries: Mapping[str, str] - _local_intro_query: bytes + _local_intro_query: str _global_intro_query: bytes _report_config_typedesc: bytes _report_config_data: bytes + _ns_tpl_sql: Optional[str] _std_schema: s_schema.Schema _refl_schema: s_schema.Schema @@ -675,10 +686,7 @@ async def introspect_user_schema(self, dbname, namespace, conn): ns_prefix = '' else: ns_prefix = namespace + '_' - ns_intro_query = _RE_BYTES_REPL_NS.sub( - ns_prefix.encode('utf-8') + rb'\1\2', - self._local_intro_query - ) + ns_intro_query = self._local_intro_query.format(ns_prefix=ns_prefix).encode('utf-8') json_data = await conn.sql_fetch_val(ns_intro_query) base_schema = s_schema.ChainedSchema( @@ -936,11 +944,16 @@ async def _load_instance_data(self): self._sys_queries = immutables.Map( {k: q.encode() for k, q in queries.items()}) - self._local_intro_query = await syscon.sql_fetch_val(b'''\ + local_intro_query = await syscon.sql_fetch_val(b'''\ SELECT text FROM edgedbinstdata.instdata WHERE key = 'local_intro_query'; ''') + self._local_intro_query = _RE_BYTES_REPL_NS.sub( + r"{ns_prefix}\2\3", + local_intro_query.decode('utf-8'), + ) + self._global_intro_query = await syscon.sql_fetch_val(b'''\ SELECT text FROM edgedbinstdata.instdata WHERE key = 'global_intro_query'; @@ -984,9 +997,11 @@ async def _load_instance_data(self): if (tpldbdump := await get_tpl_sql(syscon)) is None: tpldbdump = await gen_tpl_dump(self._cluster) await store_tpl_sql(tpldbdump, syscon) - self._ns_tpl_sql = tpldbdump + ns_tpl_sql = tpldbdump.decode() else: - self._ns_tpl_sql = tpldbdump + ns_tpl_sql = tpldbdump.decode() + + self._ns_tpl_sql = _RE_BYTES_REPL_NS.sub(repl_ignore_setting, ns_tpl_sql) finally: self._release_sys_pgcon() @@ -1361,11 +1376,11 @@ async def task(): self.create_task(task(), interruptable=True) - def _on_global_schema_change(self, namespace): + def _on_global_schema_change(self): if not self._accept_new_tasks: return - async def task(ns): + async def task(): try: await self._reintrospect_global_schema() except Exception: @@ -1373,7 +1388,7 @@ async def task(ns): 1.0, 'on_global_schema_change') raise - self.create_task(task(namespace), interruptable=True) + self.create_task(task(), interruptable=True) def _on_sys_pgcon_connection_lost(self, exc): try: @@ -1951,17 +1966,8 @@ def on_switch_over(self): ) async def create_namespace(self, be_conn: pgcon.PGConnection, name: str): - tpl_sql = _RE_BYTES_REPL_NS.sub( - name.encode('utf-8') + rb'_\1\2', - self._ns_tpl_sql, - ) - tpl_sql = re.sub( - rb'({ns_edgedbext})', - name.encode('utf-8') + rb'_edgedbext', - tpl_sql, - flags=re.MULTILINE, - ) - await be_conn.sql_execute(tpl_sql) + tpl_sql = self._ns_tpl_sql.replace("{ns_prefix}", f"{name}_") + await be_conn.sql_execute(tpl_sql.encode('utf-8')) def get_active_pgcon_num(self) -> int: return ( From 0b0c4633dbebf70d57eea00916fc988af734216c Mon Sep 17 00:00:00 2001 From: JMPRIEST Date: Tue, 23 May 2023 19:08:21 +0800 Subject: [PATCH 14/20] :construction: Finish the namespace switching logic; tests still to be added MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- edb/edgeql-parser/src/keywords.rs | 2 ++ edb/edgeql/ast.py | 4 +++ edb/edgeql/codegen.py | 7 +++++ edb/edgeql/parser/grammar/session.py | 22 +++++++++++++--- edb/server/compiler/compiler.py | 17 ++++++++++--- edb/server/compiler/dbstate.py | 11 ++++++++ edb/server/compiler/status.py | 5 ++++ edb/server/dbview/dbview.pyx | 38 ++++++++++------------------ edb/server/protocol/binary.pyx | 22 +++++++++------- 9 files changed, 88 insertions(+), 40 deletions(-) diff --git a/edb/edgeql-parser/src/keywords.rs b/edb/edgeql-parser/src/keywords.rs index 732c08e9db2..a2d760c8749 100644 --- a/edb/edgeql-parser/src/keywords.rs +++ b/edb/edgeql-parser/src/keywords.rs @@ -100,6 +100,8 @@ pub const UNRESERVED_KEYWORDS: &[&str] = &[ "version", "view", "write", + "use", + "show", ]; diff --git a/edb/edgeql/ast.py b/edb/edgeql/ast.py index 7d79b4ef2b0..775a1f74d26 100644 --- a/edb/edgeql/ast.py +++ b/edb/edgeql/ast.py @@ -226,6 +226,10 @@ class SessionResetAllAliases(BaseSessionReset): pass +class UseNameSpaceCommand(BaseSessionCommand): + name: str + + class BaseObjectRef(Base): __abstract_node__ = True diff --git a/edb/edgeql/codegen.py b/edb/edgeql/codegen.py index 6f5b1c4105d..65589a4e240 100644 --- a/edb/edgeql/codegen.py +++ b/edb/edgeql/codegen.py @@ -2191,6 +2191,13 @@ def visit_SessionResetAliasDecl( self._write_keywords('RESET ALIAS ') self.write(node.alias) + def visit_UseNameSpaceCommand( + self, + node: qlast.UseNameSpaceCommand + ) -> None: + self._write_keywords('USE NAMESPACE ') + self.write(node.name) + def visit_StartTransaction(self, node: qlast.StartTransaction) -> None: self._write_keywords('START TRANSACTION') diff --git a/edb/edgeql/parser/grammar/session.py b/edb/edgeql/parser/grammar/session.py index 4d7e932201f..1fc66ea8982 100644 --- a/edb/edgeql/parser/grammar/session.py +++ b/edb/edgeql/parser/grammar/session.py @@ -18,9 +18,7 @@ from __future__ import annotations -from edb.edgeql import ast as qlast - -from .expressions import Nonterm +from edb.pgsql import common as pg_common from .tokens import * # NOQA from .expressions import * # NOQA @@ -32,6 +30,12 @@ def reduce_SetStmt(self, *kids): def reduce_ResetStmt(self, *kids): self.val = kids[0].val + def reduce_UseNameSpaceStmt(self, *kids): + self.val = kids[0].val + + def reduce_ShowNameSpaceStmt(self, *kids): + self.val = kids[0].val + class SetStmt(Nonterm): def reduce_SET_ALIAS_Identifier_AS_MODULE_ModuleName(self, *kids): @@ -54,3 +58,15 @@ def reduce_RESET_MODULE(self, *kids): def reduce_RESET_ALIAS_STAR(self, *kids): self.val = qlast.SessionResetAllAliases() + + +class UseNameSpaceStmt(Nonterm): + def reduce_USE_NAMESPACE_Identifier(self, *kids): + self.val = qlast.UseNameSpaceCommand(name=kids[2].val) + + +class ShowNameSpaceStmt(Nonterm): + def reduce_SHOW_NAMESPACE(self, *kids): + self.val = qlast.SelectQuery( + result=qlast.StringConstant(value=pg_common.NAMESPACE) + ) diff --git a/edb/server/compiler/compiler.py b/edb/server/compiler/compiler.py index f65da65d53b..648dc4a8f74 100644 --- a/edb/server/compiler/compiler.py +++ b/edb/server/compiler/compiler.py @@ -1899,7 +1899,11 @@ def _compile_dispatch_ql( self._compile_ql_sess_state(ctx, ql), enums.Capability.SESSION_CONFIG, ) - + elif isinstance(ql, qlast.UseNameSpaceCommand): + return ( + dbstate.NameSpaceSwitchQuery(new_ns=ql.name, sql=()), + enums.Capability.SESSION_CONFIG, + ) elif isinstance(ql, qlast.ConfigOp): if ql.scope is qltypes.ConfigScope.SESSION: capability = 
enums.Capability.SESSION_CONFIG @@ -1969,6 +1973,12 @@ def _try_compile( default_cardinality = enums.Cardinality.NO_RESULT statements = edgeql.parse_block(source) statements_len = len(statements) + is_script = statements_len > 1 + + if is_script and any(isinstance(stmt, qlast.UseNameSpaceCommand) for stmt in statements): + raise errors.ProtocolError( + 'USE NAMESPACE statement is not allowed to be used in script.' + ) if ctx.skip_first: statements = statements[1:] @@ -1984,8 +1994,6 @@ def _try_compile( rv = dbstate.QueryUnitGroup() rv.namespace = ctx.namespace - - is_script = statements_len > 1 script_info = None if is_script: if ctx.expect_rollback: @@ -2213,7 +2221,8 @@ def _try_compile( unit.config_ops.append(comp.config_op) unit.has_set = True - + elif isinstance(comp, dbstate.NameSpaceSwitchQuery): + unit.ns_to_switch = comp.new_ns elif isinstance(comp, dbstate.NullQuery): pass diff --git a/edb/server/compiler/dbstate.py b/edb/server/compiler/dbstate.py index 38919e654f5..6b7e0491e53 100644 --- a/edb/server/compiler/dbstate.py +++ b/edb/server/compiler/dbstate.py @@ -80,6 +80,13 @@ class NullQuery(BaseQuery): has_dml: bool = False +@dataclasses.dataclass(frozen=True) +class NameSpaceSwitchQuery(BaseQuery): + new_ns: str + is_transactional: bool = False + single_unit: bool = True + + @dataclasses.dataclass(frozen=True) class Query(BaseQuery): @@ -316,7 +323,10 @@ class QueryUnit: # schema reflection sqls, only available if this is a ddl stmt. schema_refl_sqls: Tuple[bytes, ...] = None stdview_sqls: Tuple[bytes, ...] = None + # NameSpace to use for current compile namespace: str = defines.DEFAULT_NS + # NameSpace to switch for connection + ns_to_switch: str = None @property def has_ddl(self) -> bool: @@ -372,6 +382,7 @@ class QueryUnitGroup: ref_ids: Optional[Set[uuid.UUID]] = None # Record affected object ids for cache clear affected_obj_ids: Optional[Set[uuid.UUID]] = None + # NameSpace to use for current compile namespace: str = defines.DEFAULT_NS def __iter__(self): diff --git a/edb/server/compiler/status.py b/edb/server/compiler/status.py index 5b54b67afa7..50febc28cc2 100644 --- a/edb/server/compiler/status.py +++ b/edb/server/compiler/status.py @@ -155,6 +155,11 @@ def _sess_reset_alias(ql): return b'RESET ALIAS' +@get_status.register(qlast.UseNameSpaceCommand) +def _sess_use_ns(ql): + return f'USE NAMESPACE {ql.name}'.encode() + + @get_status.register(qlast.ConfigOp) def _sess_set_config(ql): if ql.scope == qltypes.ConfigScope.GLOBAL: diff --git a/edb/server/dbview/dbview.pyx b/edb/server/dbview/dbview.pyx index a72ff7e592f..61ecad08d7b 100644 --- a/edb/server/dbview/dbview.pyx +++ b/edb/server/dbview/dbview.pyx @@ -574,10 +574,6 @@ cdef class Database: self._views.remove(view) cdef get_state_serializer(self, namespace, protocol_version): - if namespace not in self.ns_map: - raise errors.InternalServerError( - f'NameSpace: [{namespace}] not in current db [{self.name}](ver:{self.dbver})' - ) return self.ns_map[namespace].get_state_serializer(protocol_version) def iter_views(self): @@ -725,6 +721,7 @@ cdef class DatabaseConnectionView: return self._in_tx_state_serializer else: if self._state_serializer is None: + self.valid_namespace(namespace) # Executed a DDL, recalculate the state descriptor self._state_serializer = self._db.get_state_serializer( namespace, @@ -816,18 +813,19 @@ cdef class DatabaseConnectionView: self._in_tx_user_schema_mut_pickled = None return self._in_tx_user_schema else: - if namespace in self._db.ns_map: - return self._db.ns_map[namespace].user_schema - 
raise errors.InternalServerError( - f'NameSpace: [{namespace}] not in current db [{self._db.name}](ver:{self._db.dbver})' - ) + self.valid_namespace(namespace) + return self._db.ns_map[namespace].user_schema def get_reflection_cache(self, namespace: str): - if namespace in self._db.ns_map: - return self._db.ns_map[namespace].reflection_cache - raise errors.InternalServerError( - f'NameSpace: [{namespace}] not in current db [{self._db.name}](ver:{self._db.dbver})' - ) + self.valid_namespace(namespace) + return self._db.ns_map[namespace].reflection_cache + + def valid_namespace(self, namespace: str): + if namespace not in self._db.ns_map: + raise errors.QueryError( + f'NameSpace: [{namespace}] not in current db [{self._db.name}](ver:{self._db.dbver}).' + f'Current NameSpace(s): [{", ".join(self._db.ns_map.keys())}]' + ) def get_global_schema(self): if self._in_tx: @@ -857,11 +855,7 @@ cdef class DatabaseConnectionView: except KeyError: pass - if namespace not in self._db.ns_map: - raise errors.InternalServerError( - f'NameSpace: [{namespace}] not in current db [{self._db.name}](ver:{self._db.dbver})' - ) - + self.valid_namespace(namespace) tid = self._db.ns_map[namespace].backend_ids.get(type_id) if tid is None: raise RuntimeError( @@ -1065,11 +1059,7 @@ cdef class DatabaseConnectionView: self._in_tx_with_ddl): return None - if key.namespace not in self._db.ns_map: - raise errors.InternalServerError( - f'NameSpace: [{key.namespace}] not in current db [{self._db.name}](ver:{self._db.dbver})' - ) - + self.valid_namespace(key.namespace) ns = self._db.ns_map[key.namespace] key = (key, self.get_modaliases(), self.get_session_config()) diff --git a/edb/server/protocol/binary.pyx b/edb/server/protocol/binary.pyx index a1c370f12e5..709f4ae1949 100644 --- a/edb/server/protocol/binary.pyx +++ b/edb/server/protocol/binary.pyx @@ -22,7 +22,6 @@ import collections import hashlib import json import logging -import os import time import statistics import traceback @@ -478,9 +477,7 @@ cdef class EdgeConnection(frontend.FrontendConnection): f'accept connections' ) - # for local test - namespace = params.get('namespace', os.environ.get('EDGEDB_TEST_CASES_NAMESPACE', defines.DEFAULT_NS)) - # namespace = params.get('namespace', edbdef.DEFAULT_NS) + namespace = params.get('namespace', edbdef.DEFAULT_NS) await self._start_connection(database, namespace) @@ -602,6 +599,7 @@ cdef class EdgeConnection(frontend.FrontendConnection): assert type(dbv) is dbview.DatabaseConnectionView self._dbview = dbv self.dbname = database + dbv.valid_namespace(namespace) self.namespace = namespace self._con_status = EDGECON_STARTED @@ -1275,11 +1273,17 @@ cdef class EdgeConnection(frontend.FrontendConnection): elif len(query_unit_group) > 1: await self._execute_script(compiled, args) else: - use_prep = ( - len(query_unit_group) == 1 - and bool(query_unit_group[0].sql_hash) - ) - await self._execute(compiled, args, use_prep) + if len(query_unit_group) == 1 and query_unit_group[0].ns_to_switch is not None: + new_ns = query_unit_group[0].ns_to_switch + self.get_dbview().valid_namespace(new_ns) + self.namespace = query_unit_group[0].ns_to_switch + logger.info(f'NameSpace changed to {query_unit_group[0].ns_to_switch} in current connection.') + else: + use_prep = ( + len(query_unit_group) == 1 + and bool(query_unit_group[0].sql_hash) + ) + await self._execute(compiled, args, use_prep) if self._cancelled: raise ConnectionAbortedError From edd33f6ad9a5edce388b36aca9d6a81a02bff23e Mon Sep 17 00:00:00 2001 From: JMPRIEST Date: Tue, 30 May 2023 
16:16:22 +0800 Subject: [PATCH 15/20] :sparkles: Add dump & restore support for namespaces :white_check_mark: Add namespace test code MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- edb/edgeql/parser/grammar/ddl.py | 60 ++--- edb/graphql/extension.pyx | 13 +- edb/ir/ast.py | 2 + edb/ir/typeutils.py | 6 + edb/pgsql/codegen.py | 5 +- edb/pgsql/common.py | 14 +- edb/pgsql/compiler/config.py | 4 +- edb/pgsql/compiler/dml.py | 2 +- edb/pgsql/dbops/base.py | 6 +- edb/pgsql/types.py | 8 +- edb/schema/schema.py | 2 +- edb/server/bootstrap.py | 8 + edb/server/compiler/compiler.py | 23 +- edb/server/compiler/errormech.py | 2 +- edb/server/dbview/dbview.pxd | 14 +- edb/server/dbview/dbview.pyx | 25 +- edb/server/protocol/args_ser.pyx | 2 +- edb/server/protocol/binary.pyx | 391 ++++++++++++++++------------ edb/server/protocol/binary_v0.pyx | 390 +--------------------------- edb/server/protocol/consts.pxi | 2 + edb/server/protocol/infer_expr.py | 9 +- edb/server/server.py | 21 +- edb/testbase/http.py | 13 +- edb/testbase/server.py | 40 ++- tests/test_edgeql_data_migration.py | 1 + tests/test_http_graphql_query.py | 3 +- tests/test_namespace.py | 137 ++++++++++ 27 files changed, 545 insertions(+), 658 deletions(-) create mode 100644 tests/test_namespace.py diff --git a/edb/edgeql/parser/grammar/ddl.py b/edb/edgeql/parser/grammar/ddl.py index 8a8c92d9d49..d319f386449 100644 --- a/edb/edgeql/parser/grammar/ddl.py +++ b/edb/edgeql/parser/grammar/ddl.py @@ -667,34 +667,10 @@ class DropDatabaseStmt(Nonterm): def reduce_DROP_DATABASE_DatabaseName(self, *kids): self.val = qlast.DropDatabase(name=kids[2].val) + # # NAMESPACE # - - -class NameSpaceName(Nonterm): - - def reduce_Identifier(self, kid): - self.val = qlast.ObjectRef( - module=None, - name=kid.val - ) - - def reduce_ReservedKeyword(self, *kids): - name = kids[0].val - if ( - name[:2] == '__' and name[-2:] == '__' - ): - raise EdgeQLSyntaxError( - "identifiers surrounded by double underscores are forbidden", - context=kids[0].context) - - self.val = qlast.ObjectRef( - module=None, - name=name - ) - - class NameSpaceStmt(Nonterm): def reduce_CreateNameSpaceStmt(self, *kids): @@ -704,30 +680,24 @@ def reduce_DropNameSpaceStmt(self, *kids): self.val = kids[0].val -# -# CREATE NAMESPACE -# - - -commands_block( - 'CreateNameSpace', - SetFieldStmt, -) - - class CreateNameSpaceStmt(Nonterm): - def reduce_CREATE_NAMESPACE_NameSpaceName(self, *kids): - """%reduce CREATE NAMESPACE NameSpaceName - """ - self.val = qlast.CreateNameSpace(name=kids[2].val) + def reduce_CREATE_NAMESPACE_Identifier(self, *kids): + self.val = qlast.CreateNameSpace( + name=qlast.ObjectRef( + module=None, + name=kids[2].val + ) + ) -# -# DROP NAMESPACE -# class DropNameSpaceStmt(Nonterm): - def reduce_DROP_NAMESPACE_DatabaseName(self, *kids): - self.val = qlast.DropNameSpace(name=kids[2].val) + def reduce_DROP_NAMESPACE_Identifier(self, *kids): + self.val = qlast.DropNameSpace( + name=qlast.ObjectRef( + module=None, + name=kids[2].val + ) + ) # diff --git a/edb/graphql/extension.pyx b/edb/graphql/extension.pyx index c1e3f2ce81e..61f5b6201ef 100644 --- a/edb/graphql/extension.pyx +++ b/edb/graphql/extension.pyx @@ -37,7 +37,7 @@ from edb import _graphql_rewrite from edb import errors from edb.graphql import errors as gql_errors from edb.server.dbview cimport dbview -from edb.server import compiler, defines +from edb.server import compiler 
from edb.server import defines as edbdef from edb.server.pgcon import errors as pgerrors from edb.server.protocol import execute @@ -97,7 +97,7 @@ async def handle_request( globals = None query = None module = None - namespace = defines.DEFAULT_NS + namespace = edbdef.DEFAULT_NS limit = 0 try: @@ -112,7 +112,7 @@ async def handle_request( variables = body.get('variables') module = body.get('module') limit = body.get('limit', 0) - namespace = body.get('namespace', defines.DEFAULT_NS) + namespace = body.get('namespace', edbdef.DEFAULT_NS) globals = body.get('globals') elif request.content_type == 'application/graphql': query = request.body.decode('utf-8') @@ -163,7 +163,7 @@ async def handle_request( if namespace is not None: namespace = namespace[0] else: - namespace = defines.DEFAULT_NS + namespace = edbdef.DEFAULT_NS else: raise TypeError('expected a GET or a POST request') @@ -261,8 +261,9 @@ async def _execute( ): if namespace not in db.ns_map: - raise errors.InternalServerError( - f'NameSpace: [{namespace}] not in current db [{db.name}](ver:{db.dbver})' + raise errors.QueryError( + f'NameSpace: [{namespace}] not in current db [{db.name}](ver:{db.dbver}).' + f'Current NameSpace(s): [{", ".join(db.ns_map.keys())}]' ) ns = db.ns_map[namespace] diff --git a/edb/ir/ast.py b/edb/ir/ast.py index 44851564843..b8b05615964 100644 --- a/edb/ir/ast.py +++ b/edb/ir/ast.py @@ -162,6 +162,8 @@ class TypeRef(ImmutableBase): is_opaque_union: bool = False # True, if this describes an sequnce type is_sequence: bool = False + # True, if this contains enums + has_enum: bool = False def __repr__(self) -> str: return f'' diff --git a/edb/ir/typeutils.py b/edb/ir/typeutils.py index 4d5dc81e5a5..b9a54eedbbf 100644 --- a/edb/ir/typeutils.py +++ b/edb/ir/typeutils.py @@ -322,6 +322,11 @@ def type_to_typeref( else: ancestors = None + if isinstance(t, s_scalars.ScalarType): + has_enum = (t.get_enum_values(schema) is not None) + else: + has_enum = False + result = irast.TypeRef( id=t.id, name_hint=name, @@ -338,6 +343,7 @@ def type_to_typeref( is_abstract=t.get_abstract(schema), is_view=t.is_view(schema), is_opaque_union=t.get_is_opaque_union(schema), + has_enum=has_enum, ) elif isinstance(t, s_types.Tuple) and t.is_named(schema): schema, material_type = t.material_type(schema) diff --git a/edb/pgsql/codegen.py b/edb/pgsql/codegen.py index de9ed46073f..536ee05da3e 100644 --- a/edb/pgsql/codegen.py +++ b/edb/pgsql/codegen.py @@ -119,10 +119,7 @@ def visit_Relation(self, node): if node.schemaname is None: self.write(common.qname(node.name)) else: - if node.schemaname in defines.EDGEDB_OWNED_DBS: - self.write(common.qname(common.actual_schemaname(node.schemaname), node.name)) - else: - self.write(common.qname(node.schemaname, node.name)) + self.write(common.qname(common.actual_schemaname(node.schemaname), node.name)) def _visit_values_expr(self, node): self.new_lines = 1 diff --git a/edb/pgsql/common.py b/edb/pgsql/common.py index 36c31637871..30ac73d88d9 100644 --- a/edb/pgsql/common.py +++ b/edb/pgsql/common.py @@ -128,8 +128,10 @@ def quote_type(type_): return first + last -def get_module_backend_name(module: s_name.Name) -> str: +def get_module_backend_name(module: s_name.Name, ignore_ns=False) -> str: # standard modules go into "edgedbstd", user ones into "edgedbpub" + if ignore_ns: + return "edgedbstd" if module in s_schema.STD_MODULES else "edgedbpub" return actual_schemaname("edgedbstd") if module in s_schema.STD_MODULES else actual_schemaname("edgedbpub") @@ -188,8 +190,8 @@ def edgedb_name_to_pg_name(name: 
str, prefix_length: int = 0) -> str: return _edgedb_name_to_pg_name(name, prefix_length) -def convert_name(name, suffix='', catenate=True): - schema = get_module_backend_name(name.get_module_name()) +def convert_name(name, suffix='', catenate=True, ignore_ns=False): + schema = get_module_backend_name(name.get_module_name(), ignore_ns) if suffix: sname = f'{name.name}_{suffix}' else: @@ -223,7 +225,7 @@ def update_aspect(name, aspect): return (name[0], stripped) -def get_scalar_backend_name(id, module_name, catenate=True, *, aspect=None): +def get_scalar_backend_name(id, module_name, catenate=True, *, aspect=None, ignore_ns=False): if aspect is None: aspect = 'domain' if aspect not in ( @@ -233,7 +235,7 @@ def get_scalar_backend_name(id, module_name, catenate=True, *, aspect=None): raise ValueError( f'unexpected aspect for scalar backend name: {aspect!r}') name = s_name.QualName(module=module_name, name=str(id)) - return convert_name(name, aspect, catenate) + return convert_name(name, aspect, catenate, ignore_ns) def get_aspect_suffix(aspect): @@ -368,7 +370,7 @@ def get_index_backend_name(id, module_name, catenate=True, *, aspect=None): def get_tuple_backend_name(id, catenate=True, *, aspect=None): name = s_name.QualName(module='edgedb', name=f'{id}_t') - return convert_name(name, aspect, catenate) + return convert_name(name, aspect, catenate, ignore_ns=True) def get_backend_name(schema, obj, catenate=True, *, aspect=None): diff --git a/edb/pgsql/compiler/config.py b/edb/pgsql/compiler/config.py index 03a8d086d71..15541e25ab0 100644 --- a/edb/pgsql/compiler/config.py +++ b/edb/pgsql/compiler/config.py @@ -66,7 +66,7 @@ def compile_ConfigSet( ) fcall = pgast.FuncCall( - name=(common.actual_schemaname('edgedb'), '_alter_current_database_set'), + name=('edgedb', '_alter_current_database_set'), args=[pgast.StringConstant(val=op.backend_setting), val], ) @@ -257,7 +257,7 @@ def compile_ConfigReset( elif op.scope is qltypes.ConfigScope.DATABASE and op.backend_setting: fcall = pgast.FuncCall( - name=(common.actual_schemaname('edgedb'), '_alter_current_database_set'), + name=('edgedb', '_alter_current_database_set'), args=[ pgast.StringConstant(val=op.backend_setting), pgast.NullConstant(), diff --git a/edb/pgsql/compiler/dml.py b/edb/pgsql/compiler/dml.py index f35ad60c250..58e45743707 100644 --- a/edb/pgsql/compiler/dml.py +++ b/edb/pgsql/compiler/dml.py @@ -2372,7 +2372,7 @@ def process_link_values( ): if src_prop.out_target.is_sequence: seq_backend_name = pgast.StringConstant( - val=f'"edgedbpub"."{src_prop.out_target.id}_sequence"' + val=f'"{common.actual_schemaname("edgedbpub")}"."{src_prop.out_target.id}_sequence"' ) source_val = pgast.FuncCall( name=('currval', ), diff --git a/edb/pgsql/dbops/base.py b/edb/pgsql/dbops/base.py index 41047ede2ee..aae34aa6d24 100644 --- a/edb/pgsql/dbops/base.py +++ b/edb/pgsql/dbops/base.py @@ -225,9 +225,9 @@ def declare_var( self, type_name: Union[str, Tuple[str, str]], *, - var_name: str = '', - var_name_prefix: str = 'v', - shared: bool = False, + var_name: str='', + var_name_prefix: str='v', + shared: bool=False, ) -> str: if shared: if not var_name: diff --git a/edb/pgsql/types.py b/edb/pgsql/types.py index c9d75d4a573..bc529ca425b 100644 --- a/edb/pgsql/types.py +++ b/edb/pgsql/types.py @@ -307,9 +307,15 @@ def pg_type_from_ir_typeref( else: pg_type = base_type_name_map.get(material.id) if pg_type is None: + builtin_extending_enum = ( + material.has_enum + and str(material.name_hint.module) in s_schema.STD_MODULES_STR + ) # User-defined scalar type 
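# [Editor's note, not part of the patch:] the ignore_ns flag below pins
# scalars whose enum values come from builtin (std) modules to the unprefixed
# "edgedbstd" schema, so every namespace shares one backend enum type instead
# of each ns_-prefixed schema redefining it; bootstrap now strips CREATE
# TYPE/DOMAIN statements from the namespace template dump for the same reason.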
pg_type = common.get_scalar_backend_name( - material.id, material.name_hint.module, catenate=False) + material.id, material.name_hint.module, catenate=False, + ignore_ns=builtin_extending_enum + ) return pg_type diff --git a/edb/schema/schema.py b/edb/schema/schema.py index d5c6571b222..90cc4b30776 100644 --- a/edb/schema/schema.py +++ b/edb/schema/schema.py @@ -66,7 +66,7 @@ BUILTIN_MODULES = STD_MODULES + (sn.UnqualName('builtin'), ) -STD_MODULES_STR = {'sys', 'schema', 'cal', 'math'} +STD_MODULES_STR = {'std', 'sys', 'schema', 'cal', 'math', 'cfg', 'builtin'} # Specifies the order of processing of files and directories in lib/ STD_SOURCES = ( diff --git a/edb/server/bootstrap.py b/edb/server/bootstrap.py index 8914e1906ec..27d71791ba4 100644 --- a/edb/server/bootstrap.py +++ b/edb/server/bootstrap.py @@ -1018,6 +1018,14 @@ async def gen_tpl_dump(cluster: pgcluster.BaseCluster): exclude_schemas=['edgedbext'], dump_object_owners=False, ) + # exclude create type & domain + tpldbdump = re.sub( + rb'CREATE (?:(TYPE|DOMAIN))[^;]*;', + rb'', + tpldbdump, + flags=re.DOTALL + ) + commands = [dbops.CreateSchema(name='{ns_prefix}edgedbext')] for uuid_func in [ 'uuid_generate_v1', diff --git a/edb/server/compiler/compiler.py b/edb/server/compiler/compiler.py index 648dc4a8f74..22d8cc3e5e8 100644 --- a/edb/server/compiler/compiler.py +++ b/edb/server/compiler/compiler.py @@ -1980,6 +1980,11 @@ def _try_compile( 'USE NAMESPACE statement is not allowed to be used in script.' ) + if isinstance(statements[0], qlast.UseNameSpaceCommand) and ctx.in_tx: + raise errors.ProtocolError( + 'USE NAMESPACE statement is not allowed to be used in transaction.' + ) + if ctx.skip_first: statements = statements[1:] if not statements: # pragma: no cover @@ -2601,12 +2606,13 @@ def compile_in_tx( def describe_database_dump( self, + namespace: str, user_schema: s_schema.Schema, global_schema: s_schema.Schema, database_config: immutables.Map[str, config.SettingValue], protocol_version: Tuple[int, int], ) -> DumpDescriptor: - # TODO namespace支持 + pg_common.NAMESPACE = namespace schema = s_schema.ChainedSchema( self._std_schema, user_schema, @@ -2867,6 +2873,7 @@ def _check_dump_layout( def describe_database_restore( self, + namespace, user_schema: s_schema.Schema, global_schema: s_schema.Schema, dump_server_ver_str: Optional[str], @@ -2876,7 +2883,6 @@ def describe_database_restore( protocol_version: Tuple[int, int], external_view: Dict[str, str] ) -> RestoreDescriptor: - # TODO namespace支持 schema_object_ids = { ( s_name.name_from_string(name), @@ -2927,7 +2933,9 @@ def describe_database_restore( log_ddl_as_migrations=False, protocol_version=protocol_version, external_view=external_view, - restoring_external=True + restoring_external=True, + namespace=namespace, + bootstrap_mode=True ) else: ctx = CompileContext( @@ -2938,6 +2946,8 @@ def describe_database_restore( schema_object_ids=schema_object_ids, log_ddl_as_migrations=False, protocol_version=protocol_version, + namespace=namespace, + bootstrap_mode=True ) ctx.state.start_tx() @@ -3231,3 +3241,10 @@ class RestoreBlockDescriptor(NamedTuple): #: this will contain the recursive descriptor on which parts of #: each datum need mending. data_mending_desc: Tuple[Optional[DataMendingDescriptor], ...] 
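#   (Illustrative sketch: restore_one_namespace below is a hypothetical
#   helper, not defined by this patch; RestoreSchemaInfo and
#   schema_info_by_ns are real names from the changes that follow.)
#   With dumps now carrying one schema section per namespace, a restore
#   conceptually iterates a mapping of namespace -> RestoreSchemaInfo:
#
#       schema_info_by_ns: Dict[str, RestoreSchemaInfo] = {...}
#       for ns, info in schema_info_by_ns.items():
#           restore_one_namespace(ns, info.schema_ddl, info.schema_ids,
#                                 info.blocks, info.external_views)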
+ + +class RestoreSchemaInfo(NamedTuple): + schema_ddl: bytes + schema_ids: List[Tuple] + blocks: List + external_views: List[Tuple] diff --git a/edb/server/compiler/errormech.py b/edb/server/compiler/errormech.py index 78aea17faeb..37e37528d69 100644 --- a/edb/server/compiler/errormech.py +++ b/edb/server/compiler/errormech.py @@ -122,7 +122,7 @@ class ErrorDetails(NamedTuple): pgtype_re = re.compile( '|'.join(fr'\b{key}\b' for key in types.base_type_name_map_r)) enum_re = re.compile( - r'(?P
<p>enum) (?P<v>edgedb([\w-]+)."(?P<id>[\w-]+)_domain")')
+    r'(?P<p>
enum) (?P(?:(.*)_)?edgedb([\w-]+)."(?P[\w-]+)_domain")') def translate_pgtype(schema, msg): diff --git a/edb/server/dbview/dbview.pxd b/edb/server/dbview/dbview.pxd index 0a3e820fd43..c38af7fc30a 100644 --- a/edb/server/dbview/dbview.pxd +++ b/edb/server/dbview/dbview.pxd @@ -183,7 +183,7 @@ cdef class DatabaseConnectionView: cdef rollback_tx_to_savepoint(self, name) cdef declare_savepoint(self, namespace, name, spid) cdef recover_aliases_and_config(self, modaliases, config, globals) - cdef abort_tx(self) + cpdef abort_tx(self) cpdef in_tx(self) cpdef in_tx_error(self) @@ -193,11 +193,11 @@ cdef class DatabaseConnectionView: cdef tx_error(self) - cdef start(self, query_unit) + cpdef start(self, query_unit) cdef _start_tx(self, namespace) cdef _apply_in_tx(self, query_unit) cdef start_implicit(self, query_unit) - cdef on_error(self) + cpdef on_error(self) cdef commit_implicit_tx( self, namespace, user_schema, user_schema_unpacked, user_schema_mutation, global_schema, @@ -224,8 +224,8 @@ cdef class DatabaseConnectionView: cpdef get_modaliases(self) cdef bytes serialize_state(self) - cdef bint is_state_desc_changed(self, namespace) + cpdef bint is_state_desc_changed(self, namespace) cdef describe_state(self, namespace) - cdef encode_state(self) - cdef decode_state(self, type_id, data, namespace) - cdef inline recode_global(self, serializer, k, v) + cpdef encode_state(self) + cpdef decode_state(self, type_id, data, namespace) + cdef inline recode_global(self, serializer, namespace, k, v) diff --git a/edb/server/dbview/dbview.pyx b/edb/server/dbview/dbview.pyx index 61ecad08d7b..3b99efa138a 100644 --- a/edb/server/dbview/dbview.pyx +++ b/edb/server/dbview/dbview.pyx @@ -692,7 +692,7 @@ cdef class DatabaseConnectionView: self.set_session_config(config) self.set_globals(globals) - cdef abort_tx(self): + cpdef abort_tx(self): if not self.in_tx(): raise errors.InternalServerError('abort_tx(): not in transaction') self._reset_tx_state() @@ -846,7 +846,7 @@ cdef class DatabaseConnectionView: self._db._index._global_schema, ) - def resolve_backend_type_id(self, type_id, namespace: defines.DEFAULT_NS): + def resolve_backend_type_id(self, type_id, namespace): type_id = str(type_id) if self._in_tx: @@ -887,7 +887,7 @@ cdef class DatabaseConnectionView: self._session_state_db_cache = (self._config, spec) return spec - cdef bint is_state_desc_changed(self, namespace): + cpdef bint is_state_desc_changed(self, namespace): serializer = self.get_state_serializer(namespace) if not self._in_tx: # We may have executed a query, or COMMIT/ROLLBACK - just use @@ -918,7 +918,7 @@ cdef class DatabaseConnectionView: cdef describe_state(self, namespace): return self.get_state_serializer(namespace).describe() - cdef encode_state(self): + cpdef encode_state(self): modaliases = self.get_modaliases() session_config = self.get_session_config() globals_ = self.get_globals() @@ -958,7 +958,7 @@ cdef class DatabaseConnectionView: state['globals'] = {k: v.value for k, v in globals_.items()} return serializer.type_id, serializer.encode(state) - cdef decode_state(self, type_id, data, namespace): + cpdef decode_state(self, type_id, data, namespace): if not self._in_tx: # make sure we start clean self._state_serializer = None @@ -998,7 +998,7 @@ cdef class DatabaseConnectionView: globals_ = immutables.Map({ k: config.SettingValue( name=k, - value=self.recode_global(serializer, k, v), + value=self.recode_global(serializer, namespace, k, v), source='global', scope=qltypes.ConfigScope.GLOBAL, ) for k, v in state.get('globals', 
{}).items() @@ -1010,13 +1010,13 @@ cdef class DatabaseConnectionView: aliases, session_config, globals_, type_id, data ) - cdef inline recode_global(self, serializer, k, v): + cdef inline recode_global(self, serializer, namespace, k, v): if v and v[:4] == b'\x00\x00\x00\x01': array_type_id = serializer.get_global_array_type_id(k) if array_type_id: va = bytearray(v) va[8:12] = INT32_PACKER( - self.resolve_backend_type_id(array_type_id) + self.resolve_backend_type_id(array_type_id, namespace) ) v = bytes(va) return v @@ -1039,6 +1039,9 @@ cdef class DatabaseConnectionView: def server(self): return self._db._index._server + def iter_ns_name(self): + return iter(self._db.ns_map.keys()) + cpdef in_tx(self): return self._in_tx @@ -1095,7 +1098,7 @@ cdef class DatabaseConnectionView: if self._in_tx: self._tx_error = True - cdef start(self, query_unit): + cpdef start(self, query_unit): if self._tx_error: self.raise_in_tx_error() @@ -1152,7 +1155,7 @@ cdef class DatabaseConnectionView: self._apply_in_tx(query_unit) - cdef on_error(self): + cpdef on_error(self): self.tx_error() async def in_tx_persist_schema(self, be_conn): @@ -1604,7 +1607,7 @@ cdef class DatabaseIndex: def unregister_ns(self, dbname, namespace): if dbname not in self._dbs: return - self._dbs[dbname].ns_map.pop(namespace) + self._dbs[dbname].ns_map.pop(namespace, None) def unregister_db(self, dbname): self._dbs.pop(dbname) diff --git a/edb/server/protocol/args_ser.pyx b/edb/server/protocol/args_ser.pyx index b55d904d5b0..0505cc2f9ed 100644 --- a/edb/server/protocol/args_ser.pyx +++ b/edb/server/protocol/args_ser.pyx @@ -171,7 +171,7 @@ cdef WriteBuffer recode_bind_args( if param.array_type_id is not None: # ndimensions + flags array_tid = dbv.resolve_backend_type_id( - param.array_type_id) + param.array_type_id, qug.namespace) out_buf.write_cstr(data, 8) out_buf.write_int32(array_tid) out_buf.write_cstr(&data[12], in_len - 12) diff --git a/edb/server/protocol/binary.pyx b/edb/server/protocol/binary.pyx index 709f4ae1949..ef0aeeebd4e 100644 --- a/edb/server/protocol/binary.pyx +++ b/edb/server/protocol/binary.pyx @@ -67,6 +67,7 @@ from edb.server import defines as edbdef from edb.server.compiler import errormech from edb.server.compiler import enums from edb.server.compiler import sertypes +from edb.server.compiler.compiler import RestoreSchemaInfo from edb.server.protocol import execute from edb.server.protocol cimport frontend from edb.server.pgcon cimport pgcon @@ -603,7 +604,6 @@ cdef class EdgeConnection(frontend.FrontendConnection): self.namespace = namespace self._con_status = EDGECON_STARTED - logger.info(f'Connection started to {database}[{namespace}].') def stop_connection(self) -> None: self._con_status = EDGECON_BAD @@ -1277,7 +1277,6 @@ cdef class EdgeConnection(frontend.FrontendConnection): new_ns = query_unit_group[0].ns_to_switch self.get_dbview().valid_namespace(new_ns) self.namespace = query_unit_group[0].ns_to_switch - logger.info(f'NameSpace changed to {query_unit_group[0].ns_to_switch} in current connection.') else: use_prep = ( len(query_unit_group) == 1 @@ -1453,6 +1452,11 @@ cdef class EdgeConnection(frontend.FrontendConnection): await self.recover_from_error() + try: + self.get_dbview().valid_namespace(self.namespace) + except Exception: + self.namespace = edbdef.DEFAULT_NS + else: self.buffer.finish_message() @@ -1813,6 +1817,18 @@ cdef class EdgeConnection(frontend.FrontendConnection): self._write_waiter.set_result(True) async def dump(self): + await self._dump() + + msg_buf = 
WriteBuffer.new_message(b'C') + msg_buf.write_int16(0) # no headers + msg_buf.write_int64(0) # capabilities + msg_buf.write_len_prefixed_bytes(b'DUMP') + msg_buf.write_bytes(sertypes.NULL_TYPE_ID.bytes) + msg_buf.write_len_prefixed_bytes(b'') + self.write(msg_buf.end_message()) + self.flush() + + async def _dump(self): cdef: WriteBuffer msg_buf dbview.DatabaseConnectionView _dbview @@ -1845,7 +1861,11 @@ cdef class EdgeConnection(frontend.FrontendConnection): # # This guarantees that every pg connection and the compiler work # with the same DB state. - user_schema = await server.introspect_user_schema(dbname, pgcon) + + global_schema = await server.introspect_global_schema(pgcon) + db_config = await server.introspect_db_config(pgcon) + dump_protocol = self.max_protocol + await pgcon.sql_execute( b'''START TRANSACTION @@ -1860,27 +1880,10 @@ cdef class EdgeConnection(frontend.FrontendConnection): SET statement_timeout = 0; ''', ) - global_schema = await server.introspect_global_schema(pgcon) - db_config = await server.introspect_db_config(pgcon) - dump_protocol = self.max_protocol - - schema_ddl, schema_dynamic_ddl, schema_ids, blocks, external_ids = ( - await compiler_pool.describe_database_dump( - user_schema, - global_schema, - db_config, - dump_protocol, - ) - ) - if schema_dynamic_ddl: - for query in schema_dynamic_ddl: - result = await pgcon.sql_fetch_val(query.encode('utf-8')) - if result: - schema_ddl += '\n' + result.decode('utf-8') + namespaces = list(_dbview.iter_ns_name()) msg_buf = WriteBuffer.new_message(b'@') - msg_buf.write_int16(4) # number of headers msg_buf.write_int16(DUMP_HEADER_BLOCK_TYPE) msg_buf.write_len_prefixed_bytes(DUMP_HEADER_BLOCK_TYPE_INFO) @@ -1888,48 +1891,73 @@ cdef class EdgeConnection(frontend.FrontendConnection): msg_buf.write_len_prefixed_utf8(str(buildmeta.get_version())) msg_buf.write_int16(DUMP_HEADER_SERVER_TIME) msg_buf.write_len_prefixed_utf8(str(int(time.time()))) - - # adding external ddl & external ids - msg_buf.write_int16(DUMP_EXTERNAL_VIEW) - external_views = await self.external_views(external_ids, pgcon) - msg_buf.write_int32(len(external_views)) - for name, view_sql in external_views: - if isinstance(name, tuple): - msg_buf.write_int16(DUMP_EXTERNAL_KEY_LINK) - msg_buf.write_len_prefixed_utf8(name[0]) - msg_buf.write_len_prefixed_utf8(name[1]) - else: - msg_buf.write_int16(DUMP_EXTERNAL_KEY_OBJ) - msg_buf.write_len_prefixed_utf8(name) - msg_buf.write_len_prefixed_utf8(view_sql) + msg_buf.write_int16(DUMP_NAMESPACE_COUNT) + msg_buf.write_int32(len(namespaces)) msg_buf.write_int16(dump_protocol[0]) msg_buf.write_int16(dump_protocol[1]) - msg_buf.write_len_prefixed_utf8(schema_ddl) + all_blocks = [] - msg_buf.write_int32(len(schema_ids)) - for (tn, td, tid) in schema_ids: - msg_buf.write_len_prefixed_utf8(tn) - msg_buf.write_len_prefixed_utf8(td) - assert len(tid) == 16 - msg_buf.write_bytes(tid) # uuid + for ns in namespaces: + user_schema = await server.introspect_user_schema(dbname, ns, pgcon) - msg_buf.write_int32(len(blocks)) - for block in blocks: - assert len(block.schema_object_id.bytes) == 16 - msg_buf.write_bytes(block.schema_object_id.bytes) # uuid - msg_buf.write_len_prefixed_bytes(block.type_desc) + schema_ddl, schema_dynamic_ddl, schema_ids, blocks, external_ids = ( + await compiler_pool.describe_database_dump( + ns, + user_schema, + global_schema, + db_config, + dump_protocol, + ) + ) - msg_buf.write_int16(len(block.schema_deps)) - for depid in block.schema_deps: - assert len(depid.bytes) == 16 - 
msg_buf.write_bytes(depid.bytes) # uuid + if schema_dynamic_ddl: + for query in schema_dynamic_ddl: + result = await pgcon.sql_fetch_val(query.encode('utf-8')) + if result: + schema_ddl += '\n' + result.decode('utf-8') + + all_blocks.extend(blocks) + + msg_buf.write_len_prefixed_utf8(ns) + + external_views = await self.external_views(external_ids, pgcon) + msg_buf.write_int32(len(external_views)) + for name, view_sql in external_views: + if isinstance(name, tuple): + msg_buf.write_int16(DUMP_EXTERNAL_KEY_LINK) + msg_buf.write_len_prefixed_utf8(name[0]) + msg_buf.write_len_prefixed_utf8(name[1]) + else: + msg_buf.write_int16(DUMP_EXTERNAL_KEY_OBJ) + msg_buf.write_len_prefixed_utf8(name) + msg_buf.write_len_prefixed_utf8(view_sql) + + msg_buf.write_len_prefixed_utf8(schema_ddl) + + msg_buf.write_int32(len(schema_ids)) + for (tn, td, tid) in schema_ids: + msg_buf.write_len_prefixed_utf8(tn) + msg_buf.write_len_prefixed_utf8(td) + assert len(tid) == 16 + msg_buf.write_bytes(tid) # uuid + + msg_buf.write_int32(len(blocks)) + for block in blocks: + assert len(block.schema_object_id.bytes) == 16 + msg_buf.write_bytes(block.schema_object_id.bytes) # uuid + msg_buf.write_len_prefixed_bytes(block.type_desc) + + msg_buf.write_int16(len(block.schema_deps)) + for depid in block.schema_deps: + assert len(depid.bytes) == 16 + msg_buf.write_bytes(depid.bytes) # uuid self._transport.write(memoryview(msg_buf.end_message())) self.flush() - blocks_queue = collections.deque(blocks) + blocks_queue = collections.deque(all_blocks) output_queue = asyncio.Queue(maxsize=2) async with taskgroup.TaskGroup() as g: @@ -1978,15 +2006,6 @@ cdef class EdgeConnection(frontend.FrontendConnection): self._in_dump_restore = False server.release_pgcon(dbname, pgcon) - msg_buf = WriteBuffer.new_message(b'C') - msg_buf.write_int16(0) # no headers - msg_buf.write_int64(0) # capabilities - msg_buf.write_len_prefixed_bytes(b'DUMP') - msg_buf.write_bytes(sertypes.NULL_TYPE_ID.bytes) - msg_buf.write_len_prefixed_bytes(b'') - self.write(msg_buf.end_message()) - self.flush() - async def external_views(self, external_ids: List[Tuple[str, str]], pgcon): views = [] for ext_name, ext_id in external_ids: @@ -2027,7 +2046,75 @@ cdef class EdgeConnection(frontend.FrontendConnection): else: _dbview.on_success(query_unit, {}) + def restore_external_views(self): + external_views = [] + external_view_num = self.buffer.read_int32() + for _ in range(external_view_num): + key_flag = self.buffer.read_int16() + if key_flag == DUMP_EXTERNAL_KEY_LINK: + obj_name = self.buffer.read_len_prefixed_utf8() + link_name = self.buffer.read_len_prefixed_utf8() + sql = self.buffer.read_len_prefixed_utf8() + external_views.append(((obj_name, link_name), sql)) + else: + name = self.buffer.read_len_prefixed_utf8() + sql = self.buffer.read_len_prefixed_utf8() + external_views.append((name, sql)) + return external_views + + def restore_schema_info(self, external_views=None): + if external_views: + external_views = external_views + else: + external_views = self.restore_external_views() + + schema_ddl = self.buffer.read_len_prefixed_bytes() + + ids_num = self.buffer.read_int32() + schema_ids = [] + for _ in range(ids_num): + schema_ids.append( + ( + self.buffer.read_len_prefixed_utf8(), + self.buffer.read_len_prefixed_utf8(), + self.buffer.read_bytes(16), + ) + ) + + block_num = self.buffer.read_int32() + blocks = [] + for _ in range(block_num): + blocks.append( + ( + self.buffer.read_bytes(16), + self.buffer.read_len_prefixed_bytes(), + ) + ) + + # Ignore deps info + for 
_ in range(self.buffer.read_int16()): + self.buffer.read_bytes(16) + + return RestoreSchemaInfo( + schema_ddl=schema_ddl, schema_ids=schema_ids, blocks=blocks, external_views=external_views + ) + + async def restore(self): + await self._restore() + + state_tid, state_data = self.get_dbview().encode_state() + + msg = WriteBuffer.new_message(b'C') + msg.write_int16(0) # no headers + msg.write_int64(0) # capabilities + msg.write_len_prefixed_bytes(b'RESTORE') + msg.write_bytes(state_tid.bytes) + msg.write_len_prefixed_bytes(state_data) + self.write(msg.end_message()) + self.flush() + + async def _restore(self): cdef: WriteBuffer msg_buf char mtype @@ -2046,33 +2133,26 @@ cdef class EdgeConnection(frontend.FrontendConnection): server = self.server compiler_pool = server.get_compiler_pool() - # TODO add namespace global_schema = _dbview.get_global_schema() - user_schema = _dbview.get_user_schema(edbdef.DEFAULT_NS) dump_server_ver_str = None headers_num = self.buffer.read_int16() - external_views = [] + ns_count = 0 + schema_info_by_ns: Dict[str, RestoreSchemaInfo] = {} + external_views=[] + default_ns = edbdef.DEFAULT_NS + for _ in range(headers_num): hdrname = self.buffer.read_int16() - if hdrname != DUMP_EXTERNAL_VIEW: + + if hdrname not in [DUMP_EXTERNAL_VIEW, DUMP_NAMESPACE_COUNT]: hdrval = self.buffer.read_len_prefixed_bytes() if hdrname == DUMP_HEADER_SERVER_VER: dump_server_ver_str = hdrval.decode('utf-8') - # getting external ddl & external ids - if hdrname == DUMP_EXTERNAL_VIEW: - external_view_num = self.buffer.read_int32() - for _ in range(external_view_num): - key_flag = self.buffer.read_int16() - if key_flag == DUMP_EXTERNAL_KEY_LINK: - obj_name = self.buffer.read_len_prefixed_utf8() - link_name = self.buffer.read_len_prefixed_utf8() - sql = self.buffer.read_len_prefixed_utf8() - external_views.append(((obj_name, link_name), sql)) - else: - name = self.buffer.read_len_prefixed_utf8() - sql = self.buffer.read_len_prefixed_utf8() - external_views.append((name, sql)) + elif hdrname == DUMP_EXTERNAL_VIEW: + external_views = self.restore_external_views() + elif hdrname == DUMP_NAMESPACE_COUNT: + ns_count = self.buffer.read_int32() proto_major = self.buffer.read_int16() proto_minor = self.buffer.read_int16() @@ -2081,36 +2161,30 @@ cdef class EdgeConnection(frontend.FrontendConnection): raise errors.ProtocolError( f'unsupported dump version {proto_major}.{proto_minor}') - schema_ddl = self.buffer.read_len_prefixed_bytes() - - ids_num = self.buffer.read_int32() - schema_ids = [] - for _ in range(ids_num): - schema_ids.append(( - self.buffer.read_len_prefixed_utf8(), - self.buffer.read_len_prefixed_utf8(), - self.buffer.read_bytes(16), - )) - - block_num = self.buffer.read_int32() - blocks = [] - for _ in range(block_num): - blocks.append(( - self.buffer.read_bytes(16), - self.buffer.read_len_prefixed_bytes(), - )) - - # Ignore deps info - for _ in range(self.buffer.read_int16()): - self.buffer.read_bytes(16) + if ns_count > 0: + for _ in range(ns_count): + ns = self.buffer.read_len_prefixed_utf8() + schema_info_by_ns[ns] = self.restore_schema_info() + else: + schema_info_by_ns[default_ns] = self.restore_schema_info(external_views=external_views) self.buffer.finish_message() dbname = _dbview.dbname pgcon = await server.acquire_pgcon(dbname) self._in_dump_restore = True + + for ns in schema_info_by_ns: + if ns == edbdef.DEFAULT_NS: + continue + await server.create_namespace(pgcon, ns) + await self._execute_utility_stmt(f'CREATE NAMESPACE {ns}', pgcon) + await server.introspect(dbname, ns) + 
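# Rough outline of the restore flow implemented in this method (the helper
# names here are hypothetical, used only to summarize the control flow):
#
#     for ns in schema_info_by_ns:                # 1. non-default namespaces
#         ensure_namespace_exists(ns)             #    created + introspected
#     for ns, info in schema_info_by_ns.items():  # 2. per-namespace schema DDL
#         compile_and_apply_schema(ns, info)
#     load_data_blocks(all_restore_blocks)        # 3. data, with table
#                                                 #    triggers disabled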
try: - _dbview.decode_state(sertypes.NULL_TYPE_ID.bytes, b'', edbdef.DEFAULT_NS) + all_restore_blocks = [] + all_tables = [] + await self._execute_utility_stmt( 'START TRANSACTION ISOLATION SERIALIZABLE', pgcon, @@ -2125,56 +2199,63 @@ cdef class EdgeConnection(frontend.FrontendConnection): SET statement_timeout = 0; ''', ) - - schema_sql_units, restore_blocks, tables = \ - await compiler_pool.describe_database_restore( - user_schema, - global_schema, - dump_server_ver_str, - schema_ddl, - schema_ids, - blocks, - proto, - dict(external_views) - ) - - for query_unit in schema_sql_units: - new_types = None - _dbview.start(query_unit) - - try: - if query_unit.config_ops: - for op in query_unit.config_ops: - if op.scope is config.ConfigScope.INSTANCE: - raise errors.ProtocolError( - 'CONFIGURE INSTANCE cannot be executed' - ' in dump restore' - ) - - if query_unit.sql: - if query_unit.ddl_stmt_id: - ddl_ret = await pgcon.run_ddl(query_unit) - if ddl_ret and ddl_ret['new_types']: - new_types = ddl_ret['new_types'] - if query_unit.schema_refl_sqls: - # no performance optimization - await pgcon.sql_execute(query_unit.schema_refl_sqls) - else: - await pgcon.sql_execute(query_unit.sql) - except Exception: - _dbview.on_error() - raise - else: - _dbview.on_success(query_unit, new_types) + for ns, (schema_ddl, schema_ids, blocks, external_views) in schema_info_by_ns.items(): + logger.info(ns) + user_schema = _dbview.get_user_schema(ns) + _dbview.decode_state(sertypes.NULL_TYPE_ID.bytes, b'', ns) + + schema_sql_units, restore_blocks, tables = \ + await compiler_pool.describe_database_restore( + ns, + user_schema, + global_schema, + dump_server_ver_str, + schema_ddl, + schema_ids, + blocks, + proto, + dict(external_views) + ) + all_restore_blocks.extend(restore_blocks) + all_tables.extend(tables) + + for query_unit in schema_sql_units: + new_types = None + _dbview.start(query_unit) + + try: + if query_unit.config_ops: + for op in query_unit.config_ops: + if op.scope is config.ConfigScope.INSTANCE: + raise errors.ProtocolError( + 'CONFIGURE INSTANCE cannot be executed' + ' in dump restore' + ) + + if query_unit.sql: + if query_unit.ddl_stmt_id: + ddl_ret = await pgcon.run_ddl(query_unit) + if ddl_ret and ddl_ret['new_types']: + new_types = ddl_ret['new_types'] + if query_unit.schema_refl_sqls: + # no performance optimization + await pgcon.sql_execute(query_unit.schema_refl_sqls) + else: + await pgcon.sql_execute(query_unit.sql) + except Exception: + _dbview.on_error() + raise + else: + _dbview.on_success(query_unit, new_types) restore_blocks = { b.schema_object_id: b - for b in restore_blocks + for b in all_restore_blocks } disable_trigger_q = '' enable_trigger_q = '' - for table in tables: + for table in all_tables: disable_trigger_q += ( f'ALTER TABLE {table} DISABLE TRIGGER ALL;' ) @@ -2252,20 +2333,9 @@ cdef class EdgeConnection(frontend.FrontendConnection): await server.introspect(dbname) - # TODO add namespace - if _dbview.is_state_desc_changed(edbdef.DEFAULT_NS): - self.write(self.make_state_data_description_msg()) - - state_tid, state_data = _dbview.encode_state() - - msg = WriteBuffer.new_message(b'C') - msg.write_int16(0) # no headers - msg.write_int64(0) # capabilities - msg.write_len_prefixed_bytes(b'RESTORE') - msg.write_bytes(state_tid.bytes) - msg.write_len_prefixed_bytes(state_data) - self.write(msg.end_message()) - self.flush() + for ns in schema_info_by_ns: + if _dbview.is_state_desc_changed(ns): + self.write(self.make_state_data_description_msg()) def 
_build_type_id_map_for_restore_mending(self, restore_block): type_map = {} @@ -2282,6 +2352,7 @@ cdef class EdgeConnection(frontend.FrontendConnection): type_map[desc.schema_type_id] = ( self.get_dbview().resolve_backend_type_id( desc.schema_type_id, + self.namespace ) ) @@ -2389,7 +2460,7 @@ async def run_script( EdgeConnection conn dbview.CompiledQuery compiled conn = new_edge_connection(server) - await conn._start_connection(database) + await conn._start_connection(database, edbdef.DEFAULT_NS) try: compiled = await conn.get_dbview().parse( dbview.QueryRequestInfo( diff --git a/edb/server/protocol/binary_v0.pyx b/edb/server/protocol/binary_v0.pyx index 477724a7e20..673baf23c5a 100644 --- a/edb/server/protocol/binary_v0.pyx +++ b/edb/server/protocol/binary_v0.pyx @@ -267,170 +267,7 @@ cdef class EdgeConnectionBackwardsCompatible(EdgeConnection): self.flush() async def legacy_dump(self): - cdef: - WriteBuffer msg_buf - dbview.DatabaseConnectionView _dbview - - self.reject_headers() - self.buffer.finish_message() - - _dbview = self.get_dbview() - if _dbview.txid: - raise errors.ProtocolError( - 'DUMP must not be executed while in transaction' - ) - - server = self.server - compiler_pool = server.get_compiler_pool() - - dbname = _dbview.dbname - pgcon = await server.acquire_pgcon(dbname) - self._in_dump_restore = True - try: - # To avoid having races, we want to: - # - # 1. start a transaction; - # - # 2. in the compiler process we connect to that transaction - # and re-introspect the schema in it. - # - # 3. all dump worker pg connection would work on the same - # connection. - # - # This guarantees that every pg connection and the compiler work - # with the same DB state. - - await pgcon.sql_execute( - b'''START TRANSACTION - ISOLATION LEVEL SERIALIZABLE - READ ONLY - DEFERRABLE; - - -- Disable transaction or query execution timeout - -- limits. Both clients and the server can be slow - -- during the dump/restore process. 
- SET idle_in_transaction_session_timeout = 0; - SET statement_timeout = 0; - ''', - ) - # TODO add namespace - user_schema = await server.introspect_user_schema(dbname, conn=pgcon) - global_schema = await server.introspect_global_schema(pgcon) - db_config = await server.introspect_db_config(pgcon) - dump_protocol = self.max_protocol - - schema_ddl, schema_dynamic_ddl, schema_ids, blocks, external_ids = ( - await compiler_pool.describe_database_dump( - user_schema, - global_schema, - db_config, - dump_protocol, - ) - ) - - if schema_dynamic_ddl: - for query in schema_dynamic_ddl: - result = await pgcon.sql_fetch_val(query.encode('utf-8')) - if result: - schema_ddl += '\n' + result.decode('utf-8') - - msg_buf = WriteBuffer.new_message(b'@') - - msg_buf.write_int16(4) # number of headers - msg_buf.write_int16(DUMP_HEADER_BLOCK_TYPE) - msg_buf.write_len_prefixed_bytes(DUMP_HEADER_BLOCK_TYPE_INFO) - msg_buf.write_int16(DUMP_HEADER_SERVER_VER) - msg_buf.write_len_prefixed_utf8(str(buildmeta.get_version())) - msg_buf.write_int16(DUMP_HEADER_SERVER_TIME) - msg_buf.write_len_prefixed_utf8(str(int(time.time()))) - - # adding external ddl & external ids - msg_buf.write_int16(DUMP_EXTERNAL_VIEW) - external_views = await self.external_views(external_ids, pgcon) - msg_buf.write_int32(len(external_views)) - for name, view_sql in external_views: - if isinstance(name, tuple): - msg_buf.write_int16(DUMP_EXTERNAL_KEY_LINK) - msg_buf.write_len_prefixed_utf8(name[0]) - msg_buf.write_len_prefixed_utf8(name[1]) - else: - msg_buf.write_int16(DUMP_EXTERNAL_KEY_OBJ) - msg_buf.write_len_prefixed_utf8(name) - msg_buf.write_len_prefixed_utf8(view_sql) - - msg_buf.write_int16(dump_protocol[0]) - msg_buf.write_int16(dump_protocol[1]) - msg_buf.write_len_prefixed_utf8(schema_ddl) - - msg_buf.write_int32(len(schema_ids)) - for (tn, td, tid) in schema_ids: - msg_buf.write_len_prefixed_utf8(tn) - msg_buf.write_len_prefixed_utf8(td) - assert len(tid) == 16 - msg_buf.write_bytes(tid) # uuid - - msg_buf.write_int32(len(blocks)) - for block in blocks: - assert len(block.schema_object_id.bytes) == 16 - msg_buf.write_bytes(block.schema_object_id.bytes) # uuid - msg_buf.write_len_prefixed_bytes(block.type_desc) - - msg_buf.write_int16(len(block.schema_deps)) - for depid in block.schema_deps: - assert len(depid.bytes) == 16 - msg_buf.write_bytes(depid.bytes) # uuid - - self._transport.write(memoryview(msg_buf.end_message())) - self.flush() - - blocks_queue = collections.deque(blocks) - output_queue = asyncio.Queue(maxsize=2) - - async with taskgroup.TaskGroup() as g: - g.create_task(pgcon.dump( - blocks_queue, - output_queue, - DUMP_BLOCK_SIZE, - )) - - nstops = 0 - while True: - if self._cancelled: - raise ConnectionAbortedError - - out = await output_queue.get() - if out is None: - nstops += 1 - if nstops == 1: - # we only have one worker right now - break - else: - block, block_num, data = out - - msg_buf = WriteBuffer.new_message(b'=') - msg_buf.write_int16(4) # number of headers - - msg_buf.write_int16(DUMP_HEADER_BLOCK_TYPE) - msg_buf.write_len_prefixed_bytes( - DUMP_HEADER_BLOCK_TYPE_DATA) - msg_buf.write_int16(DUMP_HEADER_BLOCK_ID) - msg_buf.write_len_prefixed_bytes( - block.schema_object_id.bytes) - msg_buf.write_int16(DUMP_HEADER_BLOCK_NUM) - msg_buf.write_len_prefixed_bytes( - str(block_num).encode()) - msg_buf.write_int16(DUMP_HEADER_BLOCK_DATA) - msg_buf.write_len_prefixed_buffer(data) - - self._transport.write(memoryview(msg_buf.end_message())) - if self._write_waiter: - await self._write_waiter - - await 
pgcon.sql_execute(b"ROLLBACK;") - - finally: - self._in_dump_restore = False - server.release_pgcon(dbname, pgcon) + await self._dump() msg_buf = WriteBuffer.new_message(b'C') msg_buf.write_int16(0) # no headers @@ -439,230 +276,7 @@ cdef class EdgeConnectionBackwardsCompatible(EdgeConnection): self.flush() async def legacy_restore(self): - cdef: - WriteBuffer msg_buf - char mtype - dbview.DatabaseConnectionView _dbview - - _dbview = self.get_dbview() - if _dbview.txid: - raise errors.ProtocolError( - 'RESTORE must not be executed while in transaction' - ) - - self.reject_headers() - self.buffer.read_int16() # discard -j level - - # Now parse the embedded dump header message: - - server = self.server - compiler_pool = server.get_compiler_pool() - # TODO add namespace - global_schema = _dbview.get_global_schema() - user_schema = _dbview.get_user_schema(defines.DEFAULT_NS) - - dump_server_ver_str = None - headers_num = self.buffer.read_int16() - external_views = [] - for _ in range(headers_num): - hdrname = self.buffer.read_int16() - if hdrname != DUMP_EXTERNAL_VIEW: - hdrval = self.buffer.read_len_prefixed_bytes() - if hdrname == DUMP_HEADER_SERVER_VER: - dump_server_ver_str = hdrval.decode('utf-8') - # getting external ddl & external ids - if hdrname == DUMP_EXTERNAL_VIEW: - external_view_num = self.buffer.read_int32() - for _ in range(external_view_num): - key_flag = self.buffer.read_int16() - if key_flag == DUMP_EXTERNAL_KEY_LINK: - obj_name = self.buffer.read_len_prefixed_utf8() - link_name = self.buffer.read_len_prefixed_utf8() - sql = self.buffer.read_len_prefixed_utf8() - external_views.append(((obj_name, link_name), sql)) - else: - name = self.buffer.read_len_prefixed_utf8() - sql = self.buffer.read_len_prefixed_utf8() - external_views.append((name, sql)) - - proto_major = self.buffer.read_int16() - proto_minor = self.buffer.read_int16() - proto = (proto_major, proto_minor) - if proto > DUMP_VER_MAX or proto < DUMP_VER_MIN: - raise errors.ProtocolError( - f'unsupported dump version {proto_major}.{proto_minor}') - - schema_ddl = self.buffer.read_len_prefixed_bytes() - - ids_num = self.buffer.read_int32() - schema_ids = [] - for _ in range(ids_num): - schema_ids.append(( - self.buffer.read_len_prefixed_utf8(), - self.buffer.read_len_prefixed_utf8(), - self.buffer.read_bytes(16), - )) - - block_num = self.buffer.read_int32() - blocks = [] - for _ in range(block_num): - blocks.append(( - self.buffer.read_bytes(16), - self.buffer.read_len_prefixed_bytes(), - )) - - # Ignore deps info - for _ in range(self.buffer.read_int16()): - self.buffer.read_bytes(16) - - self.buffer.finish_message() - dbname = _dbview.dbname - pgcon = await server.acquire_pgcon(dbname) - - self._in_dump_restore = True - try: - # TODO add namespace - _dbview.decode_state(sertypes.NULL_TYPE_ID.bytes, b'', edbdef.DEFAULT_NS) - await self._execute_utility_stmt( - 'START TRANSACTION ISOLATION SERIALIZABLE', - pgcon, - ) - - await pgcon.sql_execute( - b''' - -- Disable transaction or query execution timeout - -- limits. Both clients and the server can be slow - -- during the dump/restore process. 
- SET idle_in_transaction_session_timeout = 0; - SET statement_timeout = 0; - ''', - ) - - schema_sql_units, restore_blocks, tables = \ - await compiler_pool.describe_database_restore( - user_schema, - global_schema, - dump_server_ver_str, - schema_ddl, - schema_ids, - blocks, - proto, - dict(external_views) - ) - - for query_unit in schema_sql_units: - new_types = None - _dbview.start(query_unit) - - try: - if query_unit.config_ops: - for op in query_unit.config_ops: - if op.scope is config.ConfigScope.INSTANCE: - raise errors.ProtocolError( - 'CONFIGURE INSTANCE cannot be executed' - ' in dump restore' - ) - - if query_unit.sql: - if query_unit.ddl_stmt_id: - ddl_ret = await pgcon.run_ddl(query_unit) - if ddl_ret and ddl_ret['new_types']: - new_types = ddl_ret['new_types'] - if query_unit.schema_refl_sqls: - # no performance optimization - await pgcon.sql_execute(query_unit.schema_refl_sqls) - else: - await pgcon.sql_execute(query_unit.sql) - except Exception: - _dbview.on_error() - raise - else: - _dbview.on_success(query_unit, new_types) - - restore_blocks = { - b.schema_object_id: b - for b in restore_blocks - } - - disable_trigger_q = '' - enable_trigger_q = '' - for table in tables: - disable_trigger_q += ( - f'ALTER TABLE {table} DISABLE TRIGGER ALL;' - ) - enable_trigger_q += ( - f'ALTER TABLE {table} ENABLE TRIGGER ALL;' - ) - - await pgcon.sql_execute(disable_trigger_q.encode()) - - # Send "RestoreReadyMessage" - msg = WriteBuffer.new_message(b'+') - msg.write_int16(0) # no headers - msg.write_int16(1) # -j1 - self.write(msg.end_message()) - self.flush() - - while True: - if not self.buffer.take_message(): - # Don't report idling when restoring a dump. - # This is an edge case and the client might be - # legitimately slow. - await self.wait_for_message(report_idling=False) - mtype = self.buffer.get_message_type() - - if mtype == b'=': - block_type = None - block_id = None - block_num = None - block_data = None - - num_headers = self.buffer.read_int16() - for _ in range(num_headers): - header = self.buffer.read_int16() - if header == DUMP_HEADER_BLOCK_TYPE: - block_type = self.buffer.read_len_prefixed_bytes() - elif header == DUMP_HEADER_BLOCK_ID: - block_id = self.buffer.read_len_prefixed_bytes() - block_id = pg_UUID(block_id) - elif header == DUMP_HEADER_BLOCK_NUM: - block_num = self.buffer.read_len_prefixed_bytes() - elif header == DUMP_HEADER_BLOCK_DATA: - block_data = self.buffer.read_len_prefixed_bytes() - - self.buffer.finish_message() - - if (block_type is None or block_id is None - or block_num is None or block_data is None): - raise errors.ProtocolError('incomplete data block') - - restore_block = restore_blocks[block_id] - type_id_map = self._build_type_id_map_for_restore_mending( - restore_block) - await pgcon.restore(restore_block, block_data, type_id_map) - - elif mtype == b'.': - self.buffer.finish_message() - break - - else: - self.fallthrough() - - await pgcon.sql_execute(enable_trigger_q.encode()) - - except Exception: - await pgcon.sql_execute(b'ROLLBACK') - _dbview.abort_tx() - raise - - else: - await self._execute_utility_stmt('COMMIT', pgcon) - - finally: - self._in_dump_restore = False - server.release_pgcon(dbname, pgcon) - - await server.introspect(dbname) + await self._restore() msg = WriteBuffer.new_message(b'C') msg.write_int16(0) # no headers diff --git a/edb/server/protocol/consts.pxi b/edb/server/protocol/consts.pxi index 09cc41e1eee..b02acf02800 100644 --- a/edb/server/protocol/consts.pxi +++ b/edb/server/protocol/consts.pxi @@ -27,6 +27,8 @@ DEF 
DUMP_HEADER_SERVER_TIME = 102 DEF DUMP_HEADER_SERVER_VER = 103 DEF DUMP_HEADER_BLOCKS_INFO = 104 DEF DUMP_EXTERNAL_VIEW = 105 +DEF DUMP_NAMESPACE_COUNT = 106 +DEF DUMP_NAMESPACE_NAME = 107 DEF DUMP_HEADER_BLOCK_ID = 110 DEF DUMP_HEADER_BLOCK_NUM = 111 diff --git a/edb/server/protocol/infer_expr.py b/edb/server/protocol/infer_expr.py index 9fd35fae996..403a5ab4dd7 100644 --- a/edb/server/protocol/infer_expr.py +++ b/edb/server/protocol/infer_expr.py @@ -113,6 +113,11 @@ async def handle_request( async def execute(db, server, namespace: str, module: str, objname: str, expression: str): + if namespace not in db.ns_map: + raise errors.InternalServerError( + f'NameSpace: [{namespace}] not in current db [{db.name}](ver:{db.dbver})' + ) + ns = db.ns_map[namespace] dbver = db.dbver query_cache = server._http_query_cache @@ -129,9 +134,9 @@ async def execute(db, server, namespace: str, module: str, objname: str, express result = await compiler_pool.infer_expr( db.name, namespace, - db.user_schema, + ns.user_schema, server.get_global_schema(), - db.reflection_cache, + ns.reflection_cache, db.db_config, server.get_compilation_system_config(), name_str, diff --git a/edb/server/server.py b/edb/server/server.py index 5364bb9c36f..08991a112c4 100644 --- a/edb/server/server.py +++ b/edb/server/server.py @@ -80,15 +80,26 @@ log_metrics = logging.getLogger('edb.server.metrics') _RE_BYTES_REPL_NS = re.compile( r'(current_setting\([\']+)?' - r'(edgedb)(\.|instdata|pub|ss|std|;)', + r'(edgedb)(\.|instdata|pub\.|pub;|pub\'|ss|std\.|std\'|std;|;)([\"a-z0-9_\-]+)?', ) def repl_ignore_setting(match_obj): - maybe_setting, to_repl, tailing = match_obj.groups() + maybe_setting, schema_name, tailing, maybe_domain_name = match_obj.groups() if maybe_setting: - return maybe_setting + to_repl + tailing - return "{ns_prefix}" + to_repl + tailing + return maybe_setting + schema_name + tailing + (maybe_domain_name or '') + if maybe_domain_name: + # skip create type/domain in builtin: + # Type ends with '_t' in edgedb + # Type ends with '_t' in edgedbpub + # Domain ends with '_domain' in edgedbstd + if ( + (tailing == '.' and maybe_domain_name.strip('"').endswith('_t')) + or (tailing == 'pub.' and maybe_domain_name.strip('"').endswith('_t')) + or (tailing == 'std.' and maybe_domain_name.strip('"').endswith('_domain')) + ): + return schema_name + tailing + maybe_domain_name + return "{ns_prefix}" + schema_name + tailing + (maybe_domain_name or '') class RoleDescriptor(TypedDict): @@ -950,7 +961,7 @@ async def _load_instance_data(self): ''') self._local_intro_query = _RE_BYTES_REPL_NS.sub( - r"{ns_prefix}\2\3", + repl_ignore_setting, local_intro_query.decode('utf-8'), ) diff --git a/edb/testbase/http.py b/edb/testbase/http.py index 7e0d855826d..a277f6d2471 100644 --- a/edb/testbase/http.py +++ b/edb/testbase/http.py @@ -22,6 +22,7 @@ import contextlib import http.client import json +import os import ssl import urllib.parse import urllib.request @@ -34,6 +35,7 @@ from . 
import server from .server import PGConnMixin +from edb.server import defines class StubbornHttpConnection(http.client.HTTPSConnection): @@ -148,7 +150,8 @@ def edgeql_query( variables=None, globals=None, module=None, limit=None ): req_data = { - 'query': query + 'query': query, + 'namespace': os.environ.get('EDGEDB_TEST_CASES_NAMESPACE', defines.DEFAULT_NS) } if use_http_post: @@ -226,7 +229,8 @@ def graphql_query(self, query, *, operation_name=None, variables=None, globals=None): req_data = { - 'query': query + 'query': query, + 'namespace': self.test_ns } if operation_name is not None: @@ -319,7 +323,8 @@ def infer_expr(self, objname, module, expression): req_data = { 'object': objname, 'module': module, - 'expression': expression + 'expression': expression, + 'namespace': os.environ.get('EDGEDB_TEST_CASES_NAMESPACE', defines.DEFAULT_NS) } req = urllib.request.Request(self.http_addr, method='POST') @@ -373,7 +378,7 @@ def get_api_path(cls): def create_type(self, body): req_data = body.as_dict() - + req_data['namespace'] = os.environ.get('EDGEDB_TEST_CASES_NAMESPACE', defines.DEFAULT_NS) req = urllib.request.Request(self.http_addr, method='POST') req.add_header('Content-Type', 'application/json') req.add_header('testmode', '1') diff --git a/edb/testbase/server.py b/edb/testbase/server.py index b2fd415e6dc..7215a48e32b 100644 --- a/edb/testbase/server.py +++ b/edb/testbase/server.py @@ -48,7 +48,7 @@ from edb.edgeql import quote as qlquote from edb.pgsql import common as pgcommon from edb.pgsql import params as pgparams -from edb.server import args as edgedb_args, pgcon, pgconnparams +from edb.server import args as edgedb_args, pgcon, pgconnparams, defines from edb.server import cluster as edgedb_cluster from edb.server import defines as edgedb_defines from edb.server import main as edgedb_main @@ -1089,11 +1089,6 @@ async def create_db(): cls.con = cls.loop.run_until_complete(cls.connect(database=dbname)) - if class_set_up != 'skip': - script = cls.get_setup_script() - if script: - cls.loop.run_until_complete(cls.con.execute(script)) - @classmethod def tearDownClass(cls): script = '' @@ -1259,6 +1254,39 @@ def shape(self): class BaseQueryTestCase(DatabaseTestCase): BASE_TEST_CLASS = True + test_ns: str + + @classmethod + def setUpClass(cls): + super().setUpClass() + class_set_up = os.environ.get('EDGEDB_TEST_CASES_SET_UP', 'run') + cls.test_ns = os.environ.get('EDGEDB_TEST_CASES_NAMESPACE', defines.DEFAULT_NS) + + if class_set_up != 'skip': + if cls.test_ns != defines.DEFAULT_NS: + cls.loop.run_until_complete( + cls.con.execute(f'CREATE NAMESPACE {cls.test_ns}') + ) + cls.loop.run_until_complete( + cls.con.execute(f'use namespace {cls.test_ns}') + ) + script = cls.get_setup_script() + if script: + cls.loop.run_until_complete(cls.con.execute(script)) + + def setUp(self): + if self.test_ns != defines.DEFAULT_NS: + self.loop.run_until_complete( + self.con.execute(f'use namespace {self.test_ns}') + ) + self.loop.run_until_complete( + self.assert_query_result( + r'show namespace', + [self.test_ns] + ) + ) + + super().setUp() class DDLTestCase(BaseQueryTestCase): diff --git a/tests/test_edgeql_data_migration.py b/tests/test_edgeql_data_migration.py index e980be0eb0b..b1473bfd57b 100644 --- a/tests/test_edgeql_data_migration.py +++ b/tests/test_edgeql_data_migration.py @@ -11224,6 +11224,7 @@ async def test_edgeql_migration_recovery_in_script(self): async def test_edgeql_migration_recovery_commit_fail(self): con2 = await self.connect(database=self.con.dbname) try: + await con2.execute(f'USE 
NAMESPACE {self.test_ns};') await con2.execute('START MIGRATION TO {}') await con2.execute('POPULATE MIGRATION') diff --git a/tests/test_http_graphql_query.py b/tests/test_http_graphql_query.py index 7602bb18b2f..18cdb8e22c2 100644 --- a/tests/test_http_graphql_query.py +++ b/tests/test_http_graphql_query.py @@ -50,7 +50,8 @@ def test_graphql_http_keepalive_01(self): value } } - ''' + ''', + 'namespace': self.test_ns } data, headers, status = self.http_con_request(con, req1_data) self.assertEqual(status, 200) diff --git a/tests/test_namespace.py b/tests/test_namespace.py new file mode 100644 index 00000000000..6cf9515d32a --- /dev/null +++ b/tests/test_namespace.py @@ -0,0 +1,137 @@ +import edgedb + +from edb.schema import defines as s_def +from edb.testbase import server as tb + + +class TestNameSpace(tb.DatabaseTestCase): + TRANSACTION_ISOLATION = False + + async def test_create_drop_namespace(self): + await self.con.execute("create namespace ns1;") + await self.assert_query_result( + r"select sys::NameSpace{name} order by .name", + [{'name': s_def.DEFAULT_NS}, {'name': 'ns1'}] + ) + await self.con.execute("drop namespace ns1;") + await self.assert_query_result( + r"select sys::NameSpace{name} order by .name", + [{'name': s_def.DEFAULT_NS}] + ) + + async def test_create_namespace_invalid(self): + with self.assertRaisesRegex( + edgedb.SchemaDefinitionError, + f'NameSpace names can not be started with \'pg_\', ' + f'as such names are reserved for system schemas', + ): + await self.con.execute("create namespace pg_ns1;") + + with self.assertRaisesRegex( + edgedb.SchemaDefinitionError, + f'\'{s_def.DEFAULT_NS}\' is reserved as name for ' + f'default namespace, use others instead.' + ): + await self.con.execute(f"create namespace {s_def.DEFAULT_NS};") + + async def test_create_namespace_exists(self): + await self.con.execute("create namespace ns2;") + + with self.assertRaisesRegex( + edgedb.EdgeDBError, + 'namespace "ns2" already exists', + ): + await self.con.execute("create namespace ns2;") + + await self.con.execute("drop namespace ns2;") + + async def test_drop_namespace_invalid(self): + with self.assertRaisesRegex( + edgedb.EdgeDBError, + 'namespace "ns3" does not exist', + ): + await self.con.execute("drop namespace ns3;") + + with self.assertRaisesRegex( + edgedb.ExecutionError, + f"namespace '{s_def.DEFAULT_NS}' cannot be dropped", + ): + await self.con.execute(f"drop namespace {s_def.DEFAULT_NS};") + + await self.con.execute("create namespace n1;") + await self.con.execute("use namespace n1;") + with self.assertRaisesRegex( + edgedb.ExecutionError, + f"cannot drop the currently open current_namespace 'n1'", + ): + await self.con.execute(f"drop namespace n1;") + + async def test_use_show_namespace(self): + await self.con.execute("create namespace temp1;") + # check default + conn1 = await self.connect(database=self.get_database_name()) + conn2 = await self.connect(database=self.get_database_name()) + try: + self.assertEqual((await conn2.query('show namespace;')), [s_def.DEFAULT_NS]) + self.assertEqual((await conn1.query('show namespace;')), [s_def.DEFAULT_NS]) + + # check seperated between connection + await conn1.execute('use namespace temp1;') + self.assertEqual((await conn1.query('show namespace;')), ['temp1']) + self.assertEqual((await conn2.query('show namespace;')), [s_def.DEFAULT_NS]) + + # check use + await conn1.execute('CONFIGURE SESSION SET __internal_testmode := true;' + 'create type A;' + 'CONFIGURE SESSION SET __internal_testmode := false;') + self.assertEqual( + ( + 
await conn1.query( + 'select count((select schema::ObjectType filter .name="default::A"))' + ) + ), + [1] + ) + self.assertEqual( + ( + await conn2.query( + 'select count((select schema::ObjectType filter .name="default::A"))' + ) + ), + [0] + ) + + await conn2.execute('drop namespace temp1;') + + with self.assertRaises(edgedb.QueryError): + await conn1.query("select 1") + + self.assertEqual((await conn1.query('show namespace;')), ['default']) + finally: + await conn1.aclose() + await conn2.aclose() + + async def test_use_namespace_invalid(self): + await self.con.execute("create namespace ns4;") + try: + with self.assertRaises(edgedb.QueryError): + await self.con.execute("use namespace ns5;") + + with self.assertRaisesRegex( + edgedb.ProtocolError, + 'USE NAMESPACE statement is not allowed to be used in script.', + ): + await self.con.execute("use namespace ns4;select 1;") + + await self.con.execute("START TRANSACTION") + + with self.assertRaisesRegex( + edgedb.ProtocolError, + 'USE NAMESPACE statement is not allowed to be used in transaction.', + ): + await self.con.execute("use namespace ns4;") + + await self.con.execute("ROLLBACK") + + finally: + await self.con.execute("drop namespace ns4;") From 9f8d910122bc8375f03fcd083be33b94b386ca10 Mon Sep 17 00:00:00 2001 From: JMPRIEST Date: Tue, 30 May 2023 17:27:12 +0800 Subject: [PATCH 16/20] =?UTF-8?q?:bug:=20=E4=BF=AE=E6=AD=A3=E5=85=BC?= =?UTF-8?q?=E5=AE=B9dp/2.1=E7=9A=84external=5Fviews=E4=B8=8D=E5=AD=98?= =?UTF-8?q?=E5=9C=A8=E6=83=85=E5=86=B5=E6=97=B6=E7=9A=84=E9=80=BB=E8=BE=91?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- edb/server/protocol/binary.pyx | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/edb/server/protocol/binary.pyx b/edb/server/protocol/binary.pyx index ef0aeeebd4e..9b9bcac72c2 100644 --- a/edb/server/protocol/binary.pyx +++ b/edb/server/protocol/binary.pyx @@ -2063,7 +2063,7 @@ cdef class EdgeConnection(frontend.FrontendConnection): return external_views def restore_schema_info(self, external_views=None): - if external_views: + if external_views is not None: external_views = external_views else: external_views = self.restore_external_views() From 653384fdba52c4f794990ecfe945ad54912d1303 Mon Sep 17 00:00:00 2001 From: JMPRIEST Date: Wed, 31 May 2023 16:00:55 +0800 Subject: [PATCH 17/20] =?UTF-8?q?:white=5Fcheck=5Fmark:=20=E6=B7=BB?= =?UTF-8?q?=E5=8A=A0namespace=E5=B9=B6=E5=8F=91DDL=E6=83=85=E5=86=B5?= =?UTF-8?q?=E7=9A=84=E6=B5=8B=E8=AF=95=E4=BB=A3=E7=A0=81?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- tests/test_namespace.py | 195 +++++++++++++++++++++++++++++++++++----- 1 file changed, 175 insertions(+), 20 deletions(-) diff --git a/tests/test_namespace.py b/tests/test_namespace.py index 6cf9515d32a..91669180c70 100644 --- a/tests/test_namespace.py +++ b/tests/test_namespace.py @@ -1,3 +1,5 @@ +import json + import edgedb from edb.schema import defines as s_def @@ -7,6 +9,12 @@ class TestNameSpace(tb.DatabaseTestCase): TRANSACTION_ISOLATION = False + async def assert_query_in_conn( + self, conn, query, exp_result + ): + res = await conn.query_json(query) + self.assertEqual(json.loads(res), exp_result) + async def test_create_drop_namespace(self): await self.con.execute("create namespace ns1;") await self.assert_query_result( @@ -72,32 +80,30 @@ async def test_use_show_namespace(self): conn1 = await self.connect(database=self.get_database_name()) conn2 = await self.connect(database=self.get_database_name()) 
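# The rewrite below switches assertions from comparing the result of
# `await conn.query(...)` directly to the assert_query_in_conn() helper
# added above, which compares json.loads(await conn.query_json(...))
# against the expected value, decoupling the checks from client-side
# result types.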
try: - self.assertEqual((await conn2.query('show namespace;')), [s_def.DEFAULT_NS]) - self.assertEqual((await conn1.query('show namespace;')), [s_def.DEFAULT_NS]) + await self.assert_query_in_conn(conn1, 'show namespace;', [s_def.DEFAULT_NS]) + await self.assert_query_in_conn(conn2, 'show namespace;', [s_def.DEFAULT_NS]) # check seperated between connection await conn1.execute('use namespace temp1;') - self.assertEqual((await conn1.query('show namespace;')), ['temp1']) - self.assertEqual((await conn2.query('show namespace;')), [s_def.DEFAULT_NS]) + await self.assert_query_in_conn(conn1, 'show namespace;', ['temp1']) + await self.assert_query_in_conn(conn2, 'show namespace;', [s_def.DEFAULT_NS]) # check use - await conn1.execute('CONFIGURE SESSION SET __internal_testmode := true;' - 'create type A;' - 'CONFIGURE SESSION SET __internal_testmode := false;') - self.assertEqual( - ( - await conn1.query( - 'select count((select schema::ObjectType filter .name="default::A"))' - ) - ), + await conn1.execute( + 'CONFIGURE SESSION SET __internal_testmode := true;' + 'create type A;' + 'CONFIGURE SESSION SET __internal_testmode := false;' + ) + + await self.assert_query_in_conn( + conn1, + 'select count((select schema::ObjectType filter .name="default::A"))', [1] ) - self.assertEqual( - ( - await conn2.query( - 'select count((select schema::ObjectType filter .name="default::A"))' - ) - ), + + await self.assert_query_in_conn( + conn2, + 'select count((select schema::ObjectType filter .name="default::A"))', [0] ) @@ -106,7 +112,7 @@ async def test_use_show_namespace(self): with self.assertRaises(edgedb.QueryError): await conn1.query("select 1") - self.assertEqual((await conn1.query('show namespace;')), ['default']) + await self.assert_query_in_conn(conn1, 'show namespace;', [s_def.DEFAULT_NS]) finally: await conn1.aclose() await conn2.aclose() @@ -135,3 +141,152 @@ async def test_use_namespace_invalid(self): finally: await self.con.execute("drop namespace ns4;") + + async def test_concurrent_schema_version_change_between_ns(self): + await self.con.execute("create namespace temp1;") + await self.con.execute("create namespace temp2;") + + conn1 = await self.connect(database=self.get_database_name()) + conn2 = await self.connect(database=self.get_database_name()) + + try: + await conn1.execute('use namespace temp1;') + await conn2.execute('use namespace temp2;') + + await conn1.execute( + ''' + START MIGRATION TO { + module default { + type A5; + type Object5 { + required link a -> default::A5; + }; + }; + }; + ''' + ) + await conn1.execute('POPULATE MIGRATION') + async with conn2.transaction(): + await conn2.execute( + ''' + START MIGRATION TO { + module default { + type A6; + type Object6 { + required link a -> default::A6; + }; + }; + }; + POPULATE MIGRATION; + COMMIT MIGRATION; + ''' + ) + + await conn1.execute("COMMIT MIGRATION") + + await self.assert_query_in_conn( + conn1, + r""" + SELECT schema::ObjectType { + name, + links: { + target: {name} + } + FILTER .name = 'a' + ORDER BY .name + } + FILTER .name in {'default::Object5', 'default::Object6'}; + """, + [ + { + "name": "default::Object5", + "links": [ + { + "target": { + "name": "default::A5" + } + } + ] + } + ], + ) + + await self.assert_query_in_conn( + conn2, + r""" + SELECT schema::ObjectType { + name, + links: { + target: {name} + } + FILTER .name = 'a' + ORDER BY .name + } + FILTER .name in {'default::Object5', 'default::Object6'}; + """, + [ + { + "name": "default::Object6", + "links": [ + { + "target": { + "name": "default::A6" + } + } 
+ ] + } + ], + ) + + finally: + await conn1.aclose() + await conn2.aclose() + await self.con.execute('drop namespace temp1;') + await self.con.execute('drop namespace temp2;') + + async def test_concurrent_schema_version_change_in_one_ns(self): + await self.con.execute("create namespace temp1;") + + conn1 = await self.connect(database=self.get_database_name()) + conn2 = await self.connect(database=self.get_database_name()) + + try: + await conn1.execute('use namespace temp1;') + await conn2.execute('use namespace temp1;') + + await conn1.execute( + ''' + START MIGRATION TO { + module default { + type A5; + type Object5 { + required link a -> default::A5; + }; + }; + }; + ''' + ) + await conn1.execute('POPULATE MIGRATION') + async with conn2.transaction(): + await conn2.execute( + ''' + START MIGRATION TO { + module default { + type A6; + type Object6 { + required link a -> default::A6; + }; + }; + }; + POPULATE MIGRATION; + COMMIT MIGRATION; + ''' + ) + + with self.assertRaises(edgedb.TransactionError): + await conn1.execute("COMMIT MIGRATION") + + finally: + await conn1.aclose() + await conn2.aclose() + await self.con.execute('drop namespace temp1;') From a8cf4968c18f273d85f9f92158b28e2442049d8f Mon Sep 17 00:00:00 2001 From: JMPRIEST Date: Mon, 5 Jun 2023 14:30:21 +0800 Subject: [PATCH 18/20] =?UTF-8?q?:sparkles:=20create/drop=20db/ns=E5=A2=9E?= =?UTF-8?q?=E5=8A=A0pg=20event?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- edb/graphql/extension.pyx | 2 +- edb/server/compiler/compiler.py | 6 +++--- edb/server/dbview/dbview.pxd | 8 +++++--- edb/server/dbview/dbview.pyx | 10 ++++++++++ edb/server/pgcon/pgcon.pyx | 9 ++++++++- edb/server/protocol/binary_v0.pyx | 2 +- edb/server/protocol/execute.pyx | 27 ++++++++++++++++++++++++--- edb/server/server.py | 8 ++++++-- tests/test_namespace.py | 27 ++++++++++++++++++++++++--- 9 files changed, 82 insertions(+), 17 deletions(-) diff --git a/edb/graphql/extension.pyx b/edb/graphql/extension.pyx index 61f5b6201ef..15b293ed1b5 100644 --- a/edb/graphql/extension.pyx +++ b/edb/graphql/extension.pyx @@ -369,7 +369,7 @@ async def _execute( query_cache[cache_key] = redir key_vars2 = tuple(vars[k] for k in key_var_names) cache_key2 = ( - 'graphql', prepared_query, key_vars2, operation_name, dbver, query_only, module, limit + 'graphql', prepared_query, key_vars2, operation_name, dbver, query_only, namespace, module, limit ) query_cache[cache_key2] = qug, gql_op if gql_op.is_introspection: diff --git a/edb/server/compiler/compiler.py b/edb/server/compiler/compiler.py index 22d8cc3e5e8..77087fd784a 100644 --- a/edb/server/compiler/compiler.py +++ b/edb/server/compiler/compiler.py @@ -1976,13 +1976,13 @@ def _try_compile( is_script = statements_len > 1 if is_script and any(isinstance(stmt, qlast.UseNameSpaceCommand) for stmt in statements): - raise errors.ProtocolError( + raise errors.QueryError( 'USE NAMESPACE statement is not allowed to be used in script.' ) if isinstance(statements[0], qlast.UseNameSpaceCommand) and ctx.in_tx: - raise errors.ProtocolError( - 'USE NAMESPACE statement is not allowed to be used in transaction.' 
+ raise errors.QueryError( + 'cannot execute USE NAMESPACE in a transaction' ) if ctx.skip_first: diff --git a/edb/server/dbview/dbview.pxd b/edb/server/dbview/dbview.pxd index c38af7fc30a..d863b01bc29 100644 --- a/edb/server/dbview/dbview.pxd +++ b/edb/server/dbview/dbview.pxd @@ -27,9 +27,11 @@ cpdef enum SideEffects: SchemaChanges = 1 << 0 DatabaseConfigChanges = 1 << 1 - InstanceConfigChanges = 1 << 2 - RoleChanges = 1 << 3 - GlobalSchemaChanges = 1 << 4 + DatabaseDrop = 1 << 2 + DatabaseCreate = 1 << 3 + InstanceConfigChanges = 1 << 4 + RoleChanges = 1 << 5 + GlobalSchemaChanges = 1 << 6 @cython.final diff --git a/edb/server/dbview/dbview.pyx b/edb/server/dbview/dbview.pyx index 3b99efa138a..6c071ef9a45 100644 --- a/edb/server/dbview/dbview.pyx +++ b/edb/server/dbview/dbview.pyx @@ -1186,6 +1186,7 @@ cdef class DatabaseConnectionView: not self._in_tx and side_effects and (side_effects & SideEffects.SchemaChanges) + and not (query_unit.create_ns or query_unit.drop_ns) ): self.save_schema_mutation( query_unit.namespace, @@ -1217,6 +1218,15 @@ cdef class DatabaseConnectionView: if query_unit.stdview_sqls: self._db.schedule_stdobj_inhview_update(query_unit.stdview_sqls) side_effects |= SideEffects.SchemaChanges + + if query_unit.create_ns or query_unit.drop_ns: + self._db.dbver = next_dbver() + side_effects |= SideEffects.SchemaChanges + if query_unit.create_db: + side_effects |= SideEffects.DatabaseCreate + if query_unit.drop_db: + side_effects |= SideEffects.DatabaseDrop + if query_unit.system_config: side_effects |= SideEffects.InstanceConfigChanges if query_unit.database_config: diff --git a/edb/server/pgcon/pgcon.pyx b/edb/server/pgcon/pgcon.pyx index f9143d41525..93ade27859d 100644 --- a/edb/server/pgcon/pgcon.pyx +++ b/edb/server/pgcon/pgcon.pyx @@ -1940,7 +1940,14 @@ cdef class PGConnection: if event == 'schema-changes': dbname = event_payload['dbname'] namespace = event_payload['namespace'] - self.server._on_remote_ddl(dbname, namespace) + drop_ns = event_payload['drop_ns'] + self.server._on_remote_ddl(dbname, namespace, drop_ns) + elif event == 'database-create': + dbname = event_payload['dbname'] + self.server._on_remote_ddl(dbname, namespace=None) + elif event == 'database-drop': + dbname = event_payload['dbname'] + self.server._on_after_drop_db(dbname) elif event == 'database-config-changes': dbname = event_payload['dbname'] self.server._on_remote_database_config_change(dbname) diff --git a/edb/server/protocol/binary_v0.pyx b/edb/server/protocol/binary_v0.pyx index 673baf23c5a..65d5e89d17b 100644 --- a/edb/server/protocol/binary_v0.pyx +++ b/edb/server/protocol/binary_v0.pyx @@ -912,7 +912,7 @@ cdef class EdgeConnectionBackwardsCompatible(EdgeConnection): else: side_effects = _dbview.on_success(query_unit, new_types) if side_effects: - execute.signal_side_effects(_dbview, query_unit.namespace, side_effects) + execute.signal_side_effects(_dbview, query_unit, side_effects) if not _dbview.in_tx(): state = _dbview.serialize_state() if state is not orig_state: diff --git a/edb/server/protocol/execute.pyx b/edb/server/protocol/execute.pyx index 91e9d6f12f6..ea81c08f7da 100644 --- a/edb/server/protocol/execute.pyx +++ b/edb/server/protocol/execute.pyx @@ -161,7 +161,7 @@ async def execute( else: side_effects = dbv.on_success(query_unit, new_types) if side_effects: - signal_side_effects(dbv, query_unit.namespace, side_effects) + signal_side_effects(dbv, query_unit, side_effects) if not dbv.in_tx(): state = dbv.serialize_state() if state is not orig_state: @@ -327,7 +327,7 @@ async 
def execute_script( global_schema, cached_reflection, unit_group.affected_obj_ids ) if side_effects: - signal_side_effects(dbv, query_unit.namespace, side_effects) + signal_side_effects(dbv, query_unit, side_effects) if ( side_effects & dbview.SideEffects.SchemaChanges and group_mutation is not None @@ -379,17 +379,38 @@ async def execute_system_config( await conn.sql_execute(b'SELECT pg_reload_conf()') -def signal_side_effects(dbv, namespace, side_effects): +def signal_side_effects(dbv, query_unit, side_effects): server = dbv.server if not server._accept_new_tasks: return if side_effects & dbview.SideEffects.SchemaChanges: + if query_unit.create_ns: + namespace = query_unit.create_ns + else: + namespace = query_unit.namespace server.create_task( server._signal_sysevent( 'schema-changes', dbname=dbv.dbname, namespace=namespace, + drop_ns=query_unit.drop_ns + ), + interruptable=False, + ) + if side_effects & dbview.SideEffects.DatabaseCreate: + server.create_task( + server._signal_sysevent( + 'database-create', + dbname=query_unit.create_db, + ), + interruptable=False, + ) + if side_effects & dbview.SideEffects.DatabaseDrop: + server.create_task( + server._signal_sysevent( + 'database-drop', + dbname=query_unit.drop_db, ), interruptable=False, ) diff --git a/edb/server/server.py b/edb/server/server.py index 08991a112c4..0a39e4feaf3 100644 --- a/edb/server/server.py +++ b/edb/server/server.py @@ -1322,7 +1322,7 @@ async def _signal_sysevent(self, event, **kwargs): metrics.background_errors.inc(1.0, 'signal_sysevent') raise - def _on_remote_ddl(self, dbname, namespace): + def _on_remote_ddl(self, dbname, namespace, drop_ns=None): if not self._accept_new_tasks: return @@ -1330,7 +1330,11 @@ def _on_remote_ddl(self, dbname, namespace): # on the __edgedb_sysevent__ channel async def task(): try: - await self.introspect(dbname, namespace) + if drop_ns: + assert self._dbindex is not None + self._dbindex.unregister_ns(dbname, drop_ns) + else: + await self.introspect(dbname, namespace) except Exception: metrics.background_errors.inc(1.0, 'on_remote_ddl') raise diff --git a/tests/test_namespace.py b/tests/test_namespace.py index 91669180c70..479b71ea664 100644 --- a/tests/test_namespace.py +++ b/tests/test_namespace.py @@ -27,6 +27,27 @@ async def test_create_drop_namespace(self): [{'name': s_def.DEFAULT_NS}] ) + async def test_create_drop_namespace_invalid(self): + await self.con.execute("START TRANSACTION") + + with self.assertRaisesRegex( + edgedb.QueryError, + 'cannot execute CREATE NAMESPACE in a transaction', + ): + await self.con.execute("create namespace ns;") + + await self.con.execute("ROLLBACK") + + await self.con.execute("START TRANSACTION") + + with self.assertRaisesRegex( + edgedb.QueryError, + 'cannot execute DROP NAMESPACE in a transaction', + ): + await self.con.execute("drop namespace ns;") + + await self.con.execute("ROLLBACK") + async def test_create_namespace_invalid(self): with self.assertRaisesRegex( edgedb.SchemaDefinitionError, @@ -124,7 +145,7 @@ async def test_use_namespace_invalid(self): await self.con.execute("use namespace ns5;") with self.assertRaisesRegex( - edgedb.ProtocolError, + edgedb.QueryError, 'USE NAMESPACE statement is not allowed to be used in script.', ): await self.con.execute("use namespace ns4;select 1;") @@ -132,8 +153,8 @@ async def test_use_namespace_invalid(self): await self.con.execute("START TRANSACTION") with self.assertRaisesRegex( - edgedb.ProtocolError, - 'USE NAMESPACE statement is not allowed to be used in transaction.', + edgedb.QueryError, + 
'cannot execute USE NAMESPACE in a transaction',
         ):
             await self.con.execute("use namespace ns4;")

From 0f8e5ae1d00b968d273a0d61366cc782a5bc4bbb Mon Sep 17 00:00:00 2001
From: JMPRIEST
Date: Wed, 7 Jun 2023 17:11:30 +0800
Subject: [PATCH 19/20] :white_check_mark: check configure-related logic,
 update tests
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 edb/server/server.py        |  7 +--
 tests/test_server_config.py | 87 +++++++++++++++++++------------------
 tests/test_server_proto.py  | 25 +++++++++++
 3 files changed, 74 insertions(+), 45 deletions(-)

diff --git a/edb/server/server.py b/edb/server/server.py
index 0a39e4feaf3..3328f9e36ab 100644
--- a/edb/server/server.py
+++ b/edb/server/server.py
@@ -78,7 +78,7 @@
 ADMIN_PLACEHOLDER = "<edgedb:admin>"
 logger = logging.getLogger('edb.server')
 log_metrics = logging.getLogger('edb.server.metrics')
-_RE_BYTES_REPL_NS = re.compile(
+_RE_STR_REPL_NS = re.compile(
     r'(current_setting\([\']+)?'
     r'(edgedb)(\.|instdata|pub\.|pub;|pub\'|ss|std\.|std\'|std;|;)([\"a-z0-9_\-]+)?',
 )
@@ -86,6 +86,7 @@
 
 def repl_ignore_setting(match_obj):
     maybe_setting, schema_name, tailing, maybe_domain_name = match_obj.groups()
+    # skip changing pg_catalog.current_setting('edgedb.xxx')
     if maybe_setting:
         return maybe_setting + schema_name + tailing + (maybe_domain_name or '')
     if maybe_domain_name:
@@ -960,7 +961,7 @@ async def _load_instance_data(self):
                 WHERE key = 'local_intro_query';
             ''')
 
-            self._local_intro_query = _RE_BYTES_REPL_NS.sub(
+            self._local_intro_query = _RE_STR_REPL_NS.sub(
                 repl_ignore_setting,
                 local_intro_query.decode('utf-8'),
             )
@@ -1012,7 +1013,7 @@ async def _load_instance_data(self):
             else:
                 ns_tpl_sql = tpldbdump.decode()
 
-            self._ns_tpl_sql = _RE_BYTES_REPL_NS.sub(repl_ignore_setting, ns_tpl_sql)
+            self._ns_tpl_sql = _RE_STR_REPL_NS.sub(repl_ignore_setting, ns_tpl_sql)
 
         finally:
             self._release_sys_pgcon()
diff --git a/tests/test_server_config.py b/tests/test_server_config.py
index 8793a0693e9..1823e48f308 100644
--- a/tests/test_server_config.py
+++ b/tests/test_server_config.py
@@ -709,57 +709,60 @@ async def test_server_proto_configure_03(self):
         )
 
     async def test_server_proto_configure_04(self):
-        with self.assertRaisesRegex(
-                edgedb.UnsupportedFeatureError,
-                'CONFIGURE SESSION INSERT is not supported'):
-            await self.con.query('''
-                CONFIGURE SESSION INSERT TestSessionConfig {name := 'test_04'}
-            ''')
-
-        with self.assertRaisesRegex(
-                edgedb.ConfigurationError,
-                "unrecognized configuration object 'Unrecognized'"):
-            await self.con.query('''
-                CONFIGURE INSTANCE INSERT Unrecognized {name := 'test_04'}
-            ''')
-
-        with self.assertRaisesRegex(
-                edgedb.QueryError,
-                "must not have a FILTER clause"):
-            await self.con.query('''
-                CONFIGURE INSTANCE RESET __internal_testvalue FILTER 1 = 1;
-            ''')
+        try:
+            with self.assertRaisesRegex(
+                    edgedb.UnsupportedFeatureError,
+                    'CONFIGURE SESSION INSERT is not supported'):
+                await self.con.query('''
+                    CONFIGURE SESSION INSERT TestSessionConfig {name := 'test_04'}
+                ''')
 
-        with self.assertRaisesRegex(
-                edgedb.QueryError,
-                "non-constant expression"):
-            await self.con.query('''
-                CONFIGURE SESSION SET __internal_testmode := (random() = 0);
-            ''')
+            with self.assertRaisesRegex(
+                    edgedb.ConfigurationError,
+                    "unrecognized configuration object 'Unrecognized'"):
+                await self.con.query('''
+                    CONFIGURE INSTANCE INSERT Unrecognized {name := 'test_04'}
+                ''')
 
-        with self.assertRaisesRegex(
-                edgedb.ConfigurationError,
-                "'Subclass1' cannot be configured directly"):
-            await self.con.query('''
-                CONFIGURE INSTANCE INSERT Subclass1 {
-                    name := 'foo'
-                };
-            ''')
+            with self.assertRaisesRegex(
+                    edgedb.QueryError,
+                    "must not have a FILTER clause"):
+                await self.con.query('''
+                    CONFIGURE INSTANCE RESET __internal_testvalue FILTER 1 = 1;
+                ''')
 
-        await self.con.query('''
-            CONFIGURE INSTANCE INSERT TestInstanceConfig {
-                name := 'test_04',
-            }
-        ''')
+            with self.assertRaisesRegex(
+                    edgedb.QueryError,
+                    "non-constant expression"):
+                await self.con.query('''
+                    CONFIGURE SESSION SET __internal_testmode := (random() = 0);
+                ''')
 
-        with self.assertRaisesRegex(
-                edgedb.ConstraintViolationError,
-                "TestInstanceConfig.name violates exclusivity constraint"):
+            with self.assertRaisesRegex(
+                    edgedb.ConfigurationError,
+                    "'Subclass1' cannot be configured directly"):
+                await self.con.query('''
+                    CONFIGURE INSTANCE INSERT Subclass1 {
+                        name := 'foo'
+                    };
+                ''')
             await self.con.query('''
                 CONFIGURE INSTANCE INSERT TestInstanceConfig {
                     name := 'test_04',
                 }
             ''')
+            with self.assertRaisesRegex(
+                    edgedb.ConstraintViolationError,
+                    "TestInstanceConfig.name violates exclusivity constraint"):
+                await self.con.query('''
+                    CONFIGURE INSTANCE INSERT TestInstanceConfig {
+                        name := 'test_04',
+                    }
+                ''')
+        finally:
+            await self.con.execute('''
+                CONFIGURE INSTANCE RESET TestInstanceConfig;
+            ''')
 
     async def test_server_proto_configure_05(self):
         await self.con.execute('''
diff --git a/tests/test_server_proto.py b/tests/test_server_proto.py
index 99998d9eddc..570a314607e 100644
--- a/tests/test_server_proto.py
+++ b/tests/test_server_proto.py
@@ -859,6 +859,7 @@ async def test_server_proto_wait_cancel_01(self):
         lock_key = tb.gen_lock_key()
 
         con2 = await self.connect(database=self.con.dbname)
+        await con2.execute(f"use namespace {self.test_ns}")
 
         await self.con.query('START TRANSACTION')
         await self.con.query(
@@ -1416,6 +1417,7 @@ async def test_server_proto_tx_02(self):
         # to make sure that Opportunistic Execute isn't used.
 
         con2 = await self.connect(database=self.con.dbname)
+        await con2.execute(f"use namespace {self.test_ns}")
 
         try:
             with self.assertRaises(edgedb.DivisionByZeroError):
@@ -1448,6 +1450,7 @@ async def test_server_proto_tx_03(self):
         # to make sure that "ROLLBACK" is cached.
 
         con2 = await self.connect(database=self.con.dbname)
+        await con2.execute(f"use namespace {self.test_ns}")
 
         try:
             for _ in range(5):
@@ -1521,6 +1524,7 @@ async def test_server_proto_tx_06(self):
         query = 'SELECT 1'
 
         con2 = await self.connect(database=self.con.dbname)
+        await con2.execute(f"use namespace {self.test_ns}")
         try:
             for _ in range(5):
                 self.assertEqual(
@@ -1867,6 +1871,7 @@ async def test_server_proto_tx_16(self):
     async def test_server_proto_tx_17(self):
         con1 = self.con
         con2 = await self.connect(database=con1.dbname)
+        await con2.execute(f"use namespace {self.test_ns}")
 
         tx1 = con1.transaction()
         tx2 = con2.transaction()
@@ -2183,6 +2188,10 @@ class TestServerProtoDDL(tb.DDLTestCase):
 
     TRANSACTION_ISOLATION = False
 
+    SETUP = '''
+        CONFIGURE SESSION SET __internal_testmode := true;
+    '''
+
     async def test_server_proto_create_db_01(self):
         if not self.has_create_database:
             self.skipTest('create database is not supported by the backend')
@@ -2224,6 +2233,8 @@ async def test_server_proto_query_cache_invalidate_01(self):
         con1 = self.con
         con2 = await self.connect(database=con1.dbname)
         try:
+            await con2.execute(f"use namespace {self.test_ns}")
+            await con2.execute("CONFIGURE SESSION SET __internal_testmode := true;")
             await con2.execute(f'''
                 CREATE TYPE {typename} {{
                     CREATE REQUIRED PROPERTY prop1 -> std::str;
@@ -2271,6 +2282,8 @@ async def test_server_proto_query_cache_invalidate_02(self):
         con1 = self.con
         con2 = await self.connect(database=con1.dbname)
         try:
+            await con2.execute(f"use namespace {self.test_ns}")
+            await con2.execute("CONFIGURE SESSION SET __internal_testmode := true;")
             await con2.query(f'''
                 CREATE TYPE {typename} {{
                     CREATE REQUIRED PROPERTY prop1 -> std::str;
@@ -2326,6 +2339,8 @@ async def test_server_proto_query_cache_invalidate_03(self):
         con1 = self.con
         con2 = await self.connect(database=con1.dbname)
         try:
+            await con2.execute(f"use namespace {self.test_ns}")
+            await con2.execute("CONFIGURE SESSION SET __internal_testmode := true;")
             await con2.execute(f'''
                 CREATE TYPE {typename} {{
                     CREATE REQUIRED PROPERTY prop1 -> array<std::int64>;
@@ -2364,6 +2379,8 @@ async def test_server_proto_query_cache_invalidate_03(self):
                 await con1.query(query),
                 edgedb.Set([[1, 23]]))
 
+            await con2.execute("CONFIGURE SESSION SET __internal_testmode := false;")
+
         finally:
             await con2.aclose()
 
@@ -2373,6 +2390,8 @@ async def test_server_proto_query_cache_invalidate_04(self):
         con1 = self.con
         con2 = await self.connect(database=con1.dbname)
         try:
+            await con2.execute(f"use namespace {self.test_ns}")
+            await con2.execute("CONFIGURE SESSION SET __internal_testmode := true;")
             await con2.execute(f'''
                 CREATE TYPE {typename} {{
                     CREATE REQUIRED PROPERTY prop1 -> std::str;
@@ -2420,6 +2439,8 @@ async def test_server_proto_query_cache_invalidate_05(self):
         con1 = self.con
         con2 = await self.connect(database=con1.dbname)
         try:
+            await con2.execute(f"use namespace {self.test_ns}")
+            await con2.execute("CONFIGURE SESSION SET __internal_testmode := true;")
             await con2.execute(f'''
                 CREATE TYPE {typename} {{
                     CREATE REQUIRED PROPERTY prop1 -> std::str;
@@ -2477,6 +2498,8 @@ async def test_server_proto_query_cache_invalidate_06(self):
         con1 = self.con
         con2 = await self.connect(database=con1.dbname)
         try:
+            await con2.execute(f"use namespace {self.test_ns}")
+            await con2.execute("CONFIGURE SESSION SET __internal_testmode := true;")
             await con2.execute(f'''
                 CREATE TYPE Foo{typename};
 
@@ -2534,6 +2557,8 @@ async def test_server_proto_query_cache_invalidate_07(self):
         con1 = self.con
         con2 = await self.connect(database=con1.dbname)
         try:
+            await con2.execute(f"use namespace {self.test_ns}")
+            await con2.execute("CONFIGURE SESSION SET __internal_testmode := true;")
             await con2.execute(f'''
                 CREATE TYPE Foo{typename};

From 98f52046a0d3d10dcd7be45c9713ebb07039125b Mon Sep 17 00:00:00 2001
From: JMPRIEST
Date: Thu, 8 Jun 2023 16:12:53 +0800
Subject: [PATCH 20/20] :white_check_mark: update create type tests
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 edb/server/compiler/compiler.py | 3 ++-
 edb/server/protocol/binary.pyx  | 2 +-
 edb/testbase/http.py            | 2 +-
 edb/testbase/server.py          | 5 +++--
 tests/test_http_create_type.py  | 3 +++
 5 files changed, 10 insertions(+), 5 deletions(-)

diff --git a/edb/server/compiler/compiler.py b/edb/server/compiler/compiler.py
index 77087fd784a..a9bd1a7174d 100644
--- a/edb/server/compiler/compiler.py
+++ b/edb/server/compiler/compiler.py
@@ -579,7 +579,7 @@ def _compile_schema_storage_in_delta(
                 or '__script' in args
             ):
                 # schema version and migration
-                # update should always goes to main block
+                # update should always go to main block
                 block.add_command(cmd)
             else:
                 sp_block.add_command(cmd)
@@ -2883,6 +2883,7 @@ def describe_database_restore(
         protocol_version: Tuple[int, int],
         external_view: Dict[str, str]
     ) -> RestoreDescriptor:
+        pg_common.NAMESPACE = namespace
         schema_object_ids = {
             (
                 s_name.name_from_string(name),
diff --git a/edb/server/protocol/binary.pyx b/edb/server/protocol/binary.pyx
index 9b9bcac72c2..a424d2f306b 100644
--- a/edb/server/protocol/binary.pyx
+++ b/edb/server/protocol/binary.pyx
@@ -2200,7 +2200,7 @@ cdef class EdgeConnection(frontend.FrontendConnection):
             ''',
         )
         for ns, (schema_ddl, schema_ids, blocks, external_views) in schema_info_by_ns.items():
-            logger.info(ns)
+            logger.info(f"Restoring namespace: {ns}...")
             user_schema = _dbview.get_user_schema(ns)
             _dbview.decode_state(sertypes.NULL_TYPE_ID.bytes, b'', ns)
diff --git a/edb/testbase/http.py b/edb/testbase/http.py
index a277f6d2471..78d3b50bd1e 100644
--- a/edb/testbase/http.py
+++ b/edb/testbase/http.py
@@ -378,7 +378,7 @@ def get_api_path(cls):
 
     def create_type(self, body):
         req_data = body.as_dict()
-        req_data['namespace'] = os.environ.get('EDGEDB_TEST_CASES_NAMESPACE', defines.DEFAULT_NS)
+        req_data['namespace'] = self.test_ns
         req = urllib.request.Request(self.http_addr, method='POST')
         req.add_header('Content-Type', 'application/json')
         req.add_header('testmode', '1')
diff --git a/edb/testbase/server.py b/edb/testbase/server.py
index 7215a48e32b..276be7c0c98 100644
--- a/edb/testbase/server.py
+++ b/edb/testbase/server.py
@@ -1254,13 +1254,14 @@ def shape(self):
 class BaseQueryTestCase(DatabaseTestCase):
 
     BASE_TEST_CLASS = True
-    test_ns: str
+    test_ns: str = None
 
     @classmethod
     def setUpClass(cls):
         super().setUpClass()
         class_set_up = os.environ.get('EDGEDB_TEST_CASES_SET_UP', 'run')
-        cls.test_ns = os.environ.get('EDGEDB_TEST_CASES_NAMESPACE', defines.DEFAULT_NS)
+        if cls.test_ns is None:
+            cls.test_ns = os.environ.get('EDGEDB_TEST_CASES_NAMESPACE', defines.DEFAULT_NS)
 
         if class_set_up != 'skip':
             if cls.test_ns != defines.DEFAULT_NS:
diff --git a/tests/test_http_create_type.py b/tests/test_http_create_type.py
index 64648503352..da9b251bae6 100644
--- a/tests/test_http_create_type.py
+++ b/tests/test_http_create_type.py
@@ -4,6 +4,7 @@
 
 import edgedb
 
+from edb.server import defines
 from edb.testbase import http as http_tb
 from edb.testbase import server as server_tb
 
@@ -996,6 +997,8 @@ async def test_dml_reject(self):
 
 
 class TestHttpCreateTypeDumpRestore(TestHttpCreateType, server_tb.StableDumpTestCase):
 
+    test_ns = defines.DEFAULT_NS
+
     async def prepare(self):
         await self.prepare_external_db(dbname=f"{self.get_database_name()}_restored")
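
For reference, the per-connection semantics that the namespace tests above pin
down can be sketched in a few lines of client code. This is an illustrative
sketch only, not part of the patch series: `connect` is a placeholder for any
helper that returns a dedicated (non-pooled) async connection, like the
`self.connect(...)` helper used in tests/test_namespace.py, and it assumes a
server built from this branch with bare DDL permitted.

    async def namespace_demo(connect) -> None:
        con0 = await connect()  # control connection, stays on the default namespace
        con1 = await connect()  # will switch into the new namespace
        try:
            await con0.execute('create namespace temp1;')

            # 'use namespace' is per-connection session state: it changes what
            # con1 sees without affecting con0.
            await con1.execute('use namespace temp1;')
            assert list(await con1.query('show namespace;')) == ['temp1']

            # DDL issued by con1 lands in temp1 only, so the new type must not
            # be visible when counted from con0's default namespace.
            await con1.execute('create type A;')
            count = await con0.query(
                'select count((select schema::ObjectType filter .name="default::A"))'
            )
            assert list(count) == [0]
        finally:
            # Mirror the teardown order used in the tests: close the connection
            # that is still inside temp1 before dropping the namespace.
            await con1.aclose()
            await con0.execute('drop namespace temp1;')
            await con0.aclose()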