From 904d054a0ee3e75d8232a90bfca0ed8c6e6b0932 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Michael=20Elsd=C3=B6rfer?= Date: Mon, 17 Sep 2018 16:23:39 +0100 Subject: [PATCH 1/3] Use switchable io backend. --- gino/engine.py | 6 +++--- gino/loops/__init__.py | 10 ++++++++++ gino/loops/asyncio.py | 9 +++++++++ gino/loops/trio.py | 10 ++++++++++ 4 files changed, 32 insertions(+), 3 deletions(-) create mode 100644 gino/loops/__init__.py create mode 100644 gino/loops/asyncio.py create mode 100644 gino/loops/trio.py diff --git a/gino/engine.py b/gino/engine.py index 1ff1f1bd..840f7052 100644 --- a/gino/engine.py +++ b/gino/engine.py @@ -1,4 +1,3 @@ -import asyncio import collections import functools import sys @@ -9,6 +8,7 @@ from sqlalchemy.sql import schema from .transaction import GinoTransaction +from .loops import get_loop class _BaseDBAPIConnection: @@ -52,7 +52,7 @@ def __init__(self, cursor_cls, pool=None): super().__init__(cursor_cls) self._pool = pool self._conn = None - self._lock = asyncio.Lock() + self._lock = get_loop().Lock() @property def raw_connection(self): @@ -64,7 +64,7 @@ async def _acquire(self, timeout): await self._lock.acquire() else: before = time.monotonic() - await asyncio.wait_for(self._lock.acquire(), timeout=timeout) + await get_loop().wait_for_with_timeout(self._lock.acquire, timeout) after = time.monotonic() timeout -= after - before if self._conn is None: diff --git a/gino/loops/__init__.py b/gino/loops/__init__.py new file mode 100644 index 00000000..262b71a8 --- /dev/null +++ b/gino/loops/__init__.py @@ -0,0 +1,10 @@ +from gino.loops.asyncio import AsyncioLoop +import sniffio + + +def get_loop(): + if sniffio.current_async_library() == 'trio': + from .trio import TrioLoop + return TrioLoop + + return AsyncioLoop diff --git a/gino/loops/asyncio.py b/gino/loops/asyncio.py new file mode 100644 index 00000000..f220194f --- /dev/null +++ b/gino/loops/asyncio.py @@ -0,0 +1,9 @@ +import asyncio + + +class AsyncioLoop: + @staticmethod + async def wait_for_with_timeout(fn, timeout): + return await asyncio.wait_for(fn(), timeout=timeout) + + Lock = asyncio.Lock diff --git a/gino/loops/trio.py b/gino/loops/trio.py new file mode 100644 index 00000000..040e47f5 --- /dev/null +++ b/gino/loops/trio.py @@ -0,0 +1,10 @@ +import trio + + +class TrioLoop: + @staticmethod + async def wait_for_with_timeout(fn, timeout): + with trio.fail_after(timeout): + return await fn() + + Lock = trio.Lock \ No newline at end of file From 608bbe6b3ba85bc3eca71044bf962219572a1981 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Michael=20Elsd=C3=B6rfer?= Date: Mon, 17 Sep 2018 16:24:13 +0100 Subject: [PATCH 2/3] Add a riopg backend. --- gino/dialects/base.py | 20 +- gino/dialects/riopg.py | 456 +++++++++++++++++++++++++++++++++++++++++ gino/strategies.py | 6 +- setup.py | 4 +- 4 files changed, 478 insertions(+), 8 deletions(-) create mode 100644 gino/dialects/riopg.py diff --git a/gino/dialects/base.py b/gino/dialects/base.py index 18ae9e67..f2706471 100644 --- a/gino/dialects/base.py +++ b/gino/dialects/base.py @@ -189,12 +189,20 @@ async def execute(self, one=False, return_model=True, status=False): param_groups = [] for params in context.parameters: - replace_params = [] - for val in params: - if asyncio.iscoroutine(val): - val = await val - replace_params.append(val) - param_groups.append(replace_params) + if isinstance(params, dict): + replace_params = {} + for name, val in params.items(): + if asyncio.iscoroutine(val): + val = await val + replace_params[name] = val + param_groups.append(replace_params) + else: + replace_params = [] + for val in params: + if asyncio.iscoroutine(val): + val = await val + replace_params.append(val) + param_groups.append(replace_params) cursor = context.cursor if context.executemany: diff --git a/gino/dialects/riopg.py b/gino/dialects/riopg.py new file mode 100644 index 00000000..f3ba43b3 --- /dev/null +++ b/gino/dialects/riopg.py @@ -0,0 +1,456 @@ +# TODO: This one needs porting + +import inspect +import itertools +import time + +import riopg +import psycopg2 +import trio +from sqlalchemy import util, exc, sql +from sqlalchemy.dialects.postgresql import ( # noqa: F401 + ARRAY, + CreateEnumType, + DropEnumType, + JSON, + JSONB +) +from sqlalchemy.dialects.postgresql.base import ( + ENUM, + PGCompiler, + PGDialect, + PGExecutionContext, +) +from sqlalchemy.dialects.postgresql.psycopg2 import PGDialect_psycopg2, PGCompiler_psycopg2, PGExecutionContext_psycopg2 +from sqlalchemy.engine.url import URL +from sqlalchemy.sql import sqltypes + +from . import base + + +class RiopgDBAPI(base.BaseDBAPI): + Error = psycopg2.Error + + +class RiopgCompiler(PGCompiler_psycopg2): + pass + + +# noinspection PyAbstractClass +class RiopgExecutionContext(base.ExecutionContextOverride, + PGExecutionContext_psycopg2): + + async def _execute_scalar(self, stmt, type_): + conn = self.root_connection + if isinstance(stmt, util.text_type) and \ + not self.dialect.supports_unicode_statements: + stmt = self.dialect._encoder(stmt)[0] + + if self.dialect.positional: + default_params = self.dialect.execute_sequence_format() + else: + default_params = {} + + conn._cursor_execute(self.cursor, stmt, default_params, context=self) + r = await self.cursor.async_execute(stmt, None, default_params, 1) + r = r[0][0] + if type_ is not None: + # apply type post processors to the result + proc = type_._cached_result_processor( + self.dialect, + self.cursor.description[0][1] + ) + if proc: + return proc(r) + return r + + +class RiopgIterator: + def __init__(self, context, iterator): + self._context = context + self._iterator = iterator + + async def __anext__(self): + row = await self._iterator.__anext__() + return self._context.process_rows([row])[0] + + +class RiopgCursor(base.Cursor): + def __init__(self, context, cursor): + self._context = context + self._cursor = cursor + + async def many(self, n, *, timeout=base.DEFAULT): + if timeout is base.DEFAULT: + timeout = self._context.timeout + rows = await self._cursor.fetch(n, timeout=timeout) + return self._context.process_rows(rows) + + async def next(self, *, timeout=base.DEFAULT): + if timeout is base.DEFAULT: + timeout = self._context.timeout + row = await self._cursor.fetchrow(timeout=timeout) + if not row: + return None + return self._context.process_rows([row])[0] + + async def forward(self, n, *, timeout=base.DEFAULT): + if timeout is base.DEFAULT: + timeout = self._context.timeout + await self._cursor.forward(n, timeout=timeout) + + +class PreparedStatement(base.PreparedStatement): + def __init__(self, prepared, clause=None): + super().__init__(clause) + self._prepared = prepared + + def _get_iterator(self, *params, **kwargs): + return RiopgIterator( + self.context, self._prepared.cursor(*params, **kwargs).__aiter__()) + + async def _get_cursor(self, *params, **kwargs): + iterator = await self._prepared.cursor(*params, **kwargs) + return RiopgCursor(self.context, iterator) + + async def _execute(self, params, one): + if one: + rv = await self._prepared.fetchrow(*params) + if rv is None: + rv = [] + else: + rv = [rv] + else: + rv = await self._prepared.fetch(*params) + return self._prepared.get_statusmsg(), rv + + +class DBAPICursor(base.DBAPICursor): + def __init__(self, dbapi_conn): + self._conn = dbapi_conn + self._status = None + + async def prepare(self, context, clause=None): + # XXX https://gist.github.com/dvarrazzo/3797445 ? + timeout = context.timeout + if timeout is None: + conn = await self._conn.acquire(timeout=timeout) + else: + before = time.monotonic() + conn = await self._conn.acquire(timeout=timeout) + after = time.monotonic() + timeout -= after - before + prepared = await conn.prepare(context.statement, timeout=timeout) + try: + self._attributes = prepared.get_attributes() + except TypeError: # asyncpg <= 0.12.0 + self._attributes = [] + rv = PreparedStatement(prepared, clause) + rv.context = context + return rv + + async def async_execute(self, query, timeout, args, limit=0, many=False): + if many: + # ripog does not support this yet. Also psycopg2 just + # uses a loop? + # https://github.com/psycopg/psycopg2/issues/491 + raise RuntimeError('Not yet supported.') + + with trio.open_cancel_scope() as scope: + if timeout: + scope.deadline = trio.current_time() + timeout + + conn = await self._conn.acquire() + async with (await conn.cursor()) as cursor: + await cursor.execute(query, args) + if limit > 0: + result = await cursor.fetchall() + else: + result = [] + + self._description = cursor.description or [] + self._status = cursor.statusmessage + return result + + @property + def description(self): + return self._description + + def get_statusmsg(self): + return self._status + + +class Pool(base.Pool): + def __init__(self, url, **kwargs): + self._url = url + self._kwargs = kwargs + self._pool = None + + async def _init(self): + args = self._kwargs.copy() + # psycopg2 does not deal well with postgres+riopg urls. + url = URL( + drivername='postgres', + username=self._url.username, + password=self._url.password, + database=self._url.database, + host=self._url.host, + port=self._url.port + ) + args.update( + dsn=str(url) + ) + self._pool = await riopg.create_pool(**args) + return self + + def __await__(self): + return self._init().__await__() + + @property + def raw_pool(self): + return self._pool + + async def acquire(self, *, timeout=None): + with trio.open_cancel_scope() as scope: + if timeout: + scope.deadline = trio.current_time() + timeout + return await self._pool.acquire() + + async def release(self, conn): + await self._pool.release(conn) + + async def close(self): + await self._pool.close() + + +class Transaction(base.Transaction): + def __init__(self, conn): + self._conn = conn + + async def begin(self): + #await self._conn.start() + pass + + async def commit(self): + await self._conn.commit() + + async def rollback(self): + await self._conn.rollback() + + +class AsyncEnum(ENUM): + async def create_async(self, bind=None, checkfirst=True): + if not checkfirst or \ + not await bind.dialect.has_type( + bind, self.name, schema=self.schema): + await bind.status(CreateEnumType(self)) + + async def drop_async(self, bind=None, checkfirst=True): + if not checkfirst or \ + await bind.dialect.has_type(bind, self.name, + schema=self.schema): + await bind.status(DropEnumType(self)) + + async def _on_table_create_async(self, target, bind, checkfirst=False, + **kw): + if checkfirst or ( + not self.metadata and + not kw.get('_is_metadata_operation', False)) and \ + not self._check_for_name_in_memos(checkfirst, kw): + await self.create_async(bind=bind, checkfirst=checkfirst) + + async def _on_table_drop_async(self, target, bind, checkfirst=False, **kw): + if not self.metadata and \ + not kw.get('_is_metadata_operation', False) and \ + not self._check_for_name_in_memos(checkfirst, kw): + await self.drop_async(bind=bind, checkfirst=checkfirst) + + async def _on_metadata_create_async(self, target, bind, checkfirst=False, + **kw): + if not self._check_for_name_in_memos(checkfirst, kw): + await self.create_async(bind=bind, checkfirst=checkfirst) + + async def _on_metadata_drop_async(self, target, bind, checkfirst=False, + **kw): + if not self._check_for_name_in_memos(checkfirst, kw): + await self.drop_async(bind=bind, checkfirst=checkfirst) + + +# noinspection PyAbstractClass +class RiopgDialect(PGDialect_psycopg2, base.AsyncDialectMixin): + driver = 'riopg' + dbapi_class = RiopgDBAPI + statement_compiler = RiopgCompiler + execution_ctx_cls = RiopgExecutionContext + cursor_cls = DBAPICursor + init_kwargs = set() + + # init_kwargs = set(itertools.chain( + # *[inspect.getfullargspec(f).kwonlydefaults.keys() for f in + # [riopg.create_pool, riopg.Connection.open]])) + colspecs = util.update_copy( + PGDialect.colspecs, + { + ENUM: AsyncEnum, + sqltypes.Enum: AsyncEnum, + } + ) + + def __init__(self, *args, **kwargs): + self._pool_kwargs = {} + for k in self.init_kwargs: + if k in kwargs: + self._pool_kwargs[k] = kwargs.pop(k) + super().__init__(*args, **kwargs) + self._init_mixin() + + async def init_pool(self, url, loop): + # XXX: riopg supports a connection_factory argument we might be able to use. + #init = self.on_connect() + return await Pool(url, **self._pool_kwargs) + + # noinspection PyMethodMayBeStatic + def transaction(self, raw_conn, args, kwargs): + return Transaction(raw_conn) + + def on_connect(self): + if self.isolation_level is not None: + async def connect(conn): + await self.set_isolation_level(conn, self.isolation_level) + return connect + else: + return None + + async def set_isolation_level(self, connection, level): + """ + Given an asyncpg connection, set its isolation level. + + """ + level = level.replace('_', ' ') + if level not in self._isolation_lookup: + raise exc.ArgumentError( + "Invalid value '%s' for isolation_level. " + "Valid isolation levels for %s are %s" % + (level, self.name, ", ".join(self._isolation_lookup)) + ) + await connection.execute( + "SET SESSION CHARACTERISTICS AS TRANSACTION " + "ISOLATION LEVEL %s" % level) + await connection.execute("COMMIT") + + async def get_isolation_level(self, connection): + """ + Given an asyncpg connection, return its isolation level. + + """ + val = await connection.fetchval('show transaction isolation level') + return val.upper() + + async def has_schema(self, connection, schema): + query = ("select nspname from pg_namespace " + "where lower(nspname)=:schema") + row = await connection.first( + sql.text( + query, + bindparams=[ + sql.bindparam( + 'schema', util.text_type(schema.lower()), + type_=sqltypes.Unicode)] + ) + ) + + return bool(row) + + async def has_table(self, connection, table_name, schema=None): + # seems like case gets folded in pg_class... + if schema is None: + row = await connection.first( + sql.text( + "select relname from pg_class c join pg_namespace n on " + "n.oid=c.relnamespace where " + "pg_catalog.pg_table_is_visible(c.oid) " + "and relname=:name", + bindparams=[ + sql.bindparam('name', util.text_type(table_name), + type_=sqltypes.Unicode)] + ) + ) + else: + row = await connection.first( + sql.text( + "select relname from pg_class c join pg_namespace n on " + "n.oid=c.relnamespace where n.nspname=:schema and " + "relname=:name", + bindparams=[ + sql.bindparam('name', + util.text_type(table_name), + type_=sqltypes.Unicode), + sql.bindparam('schema', + util.text_type(schema), + type_=sqltypes.Unicode)] + ) + ) + return bool(row) + + async def has_sequence(self, connection, sequence_name, schema=None): + if schema is None: + row = await connection.first( + sql.text( + "SELECT relname FROM pg_class c join pg_namespace n on " + "n.oid=c.relnamespace where relkind='S' and " + "n.nspname=current_schema() " + "and relname=:name", + bindparams=[ + sql.bindparam('name', util.text_type(sequence_name), + type_=sqltypes.Unicode) + ] + ) + ) + else: + row = await connection.first( + sql.text( + "SELECT relname FROM pg_class c join pg_namespace n on " + "n.oid=c.relnamespace where relkind='S' and " + "n.nspname=:schema and relname=:name", + bindparams=[ + sql.bindparam('name', util.text_type(sequence_name), + type_=sqltypes.Unicode), + sql.bindparam('schema', + util.text_type(schema), + type_=sqltypes.Unicode) + ] + ) + ) + + return bool(row) + + async def has_type(self, connection, type_name, schema=None): + if schema is not None: + query = """ + SELECT EXISTS ( + SELECT * FROM pg_catalog.pg_type t, pg_catalog.pg_namespace n + WHERE t.typnamespace = n.oid + AND t.typname = :typname + AND n.nspname = :nspname + ) + """ + query = sql.text(query) + else: + query = """ + SELECT EXISTS ( + SELECT * FROM pg_catalog.pg_type t + WHERE t.typname = :typname + AND pg_type_is_visible(t.oid) + ) + """ + query = sql.text(query) + query = query.bindparams( + sql.bindparam('typname', + util.text_type(type_name), type_=sqltypes.Unicode), + ) + if schema is not None: + query = query.bindparams( + sql.bindparam('nspname', + util.text_type(schema), type_=sqltypes.Unicode), + ) + return bool(await connection.scalar(query)) diff --git a/gino/strategies.py b/gino/strategies.py index de1739fb..94bcd555 100644 --- a/gino/strategies.py +++ b/gino/strategies.py @@ -1,4 +1,5 @@ import asyncio +from sniffio import current_async_library from sqlalchemy.engine import url from sqlalchemy import util @@ -17,7 +18,10 @@ async def create(self, name_or_url, loop=None, **kwargs): if loop is None: loop = asyncio.get_event_loop() if u.drivername in {'postgresql', 'postgres'}: - u.drivername = 'postgresql+asyncpg' + if current_async_library() == 'trio': + u.drivername = 'postgresql+riopg' + else: + u.drivername = 'postgresql+asyncpg' dialect_cls = u.get_dialect() diff --git a/setup.py b/setup.py index e20ec365..6295662f 100644 --- a/setup.py +++ b/setup.py @@ -46,10 +46,12 @@ def req_file(filename): [sqlalchemy.dialects] postgresql.asyncpg = gino.dialects.asyncpg:AsyncpgDialect asyncpg = gino.dialects.asyncpg:AsyncpgDialect + postgresql.riopg = gino.dialects.riopg:RiopgDialect + riopg = gino.dialects.riopg:RiopgDialect """, license="BSD license", zip_safe=False, - keywords='orm asyncio sqlalchemy asyncpg python3 sanic aiohttp tornado', + keywords='orm asyncio sqlalchemy asyncpg python3 sanic aiohttp tornado trio', classifiers=[ 'Development Status :: 4 - Beta', 'Intended Audience :: Developers', From 257aa5a55b0ecb50e5fa823dd2558b0d8cf67380 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Michael=20Elsd=C3=B6rfer?= Date: Thu, 28 Mar 2019 16:44:56 +0000 Subject: [PATCH 3/3] Update to latest gino changes. --- gino/dialects/riopg.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/gino/dialects/riopg.py b/gino/dialects/riopg.py index f3ba43b3..e29d5a6b 100644 --- a/gino/dialects/riopg.py +++ b/gino/dialects/riopg.py @@ -156,7 +156,7 @@ async def async_execute(self, query, timeout, args, limit=0, many=False): # https://github.com/psycopg/psycopg2/issues/491 raise RuntimeError('Not yet supported.') - with trio.open_cancel_scope() as scope: + with trio.CancelScope() as scope: if timeout: scope.deadline = trio.current_time() + timeout @@ -211,7 +211,7 @@ def raw_pool(self): return self._pool async def acquire(self, *, timeout=None): - with trio.open_cancel_scope() as scope: + with trio.CancelScope() as scope: if timeout: scope.deadline = trio.current_time() + timeout return await self._pool.acquire() @@ -304,10 +304,12 @@ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self._init_mixin() - async def init_pool(self, url, loop): + async def init_pool(self, url, loop, pool_class=None): # XXX: riopg supports a connection_factory argument we might be able to use. #init = self.on_connect() - return await Pool(url, **self._pool_kwargs) + if pool_class is None: + pool_class = Pool + return await pool_class(url, **self._pool_kwargs) # noinspection PyMethodMayBeStatic def transaction(self, raw_conn, args, kwargs):