diff --git a/colanderalchemy/schema.py b/colanderalchemy/schema.py index 4925e76..b7a88f6 100644 --- a/colanderalchemy/schema.py +++ b/colanderalchemy/schema.py @@ -682,28 +682,72 @@ def objectify(self, dict_, context=None): Default: ``None``. Defaults to instantiating a new instance of the mapped class associated with this schema. """ + + + """ To persist related data, the recursive call to ``objectify`` needs + information about the corresponding sqlalchemy object that is to be + updated. This object is defined by the colanderalchemy obj, the + session it is defined in, the corresponding mapped class and + by the identity of the colanderalchemy object. + + If this identity is given by a primary key (pk), the corresponding + object can be constructed. If there is no primary key data on obj, + it is a new object and needs to be added to the session. + """ + def get_context(obj, session, class_, pk): + '''return context of obj in session''' + if isinstance(pk, tuple): + ident = tuple(obj.get(v,None) for v in pk) + else: + ident = obj.get(pk,None) + context = session.query(class_).get(ident) if ident else None + + return context + mapper = self.inspector - context = mapper.class_() if context is None else context - for attr in dict_: - if mapper.has_property(attr): + context = context if context else mapper.class_() + insp = inspect(context, raiseerr=False) + session = insp.session if (insp and insp.session) else None + + for attr in dict_: + if attr in mapper.relationships.keys(): + # handle relationship prop = mapper.get_property(attr) - if hasattr(prop, 'mapper'): - cls = prop.mapper.class_ - if prop.uselist: - # Sequence of objects - value = [self[attr].children[0].objectify(obj) - for obj in dict_[attr]] - else: - # Single object - value = self[attr].objectify(dict_[attr]) + prop_class = prop.mapper.class_ + prop_pk = tuple(v.key for v in inspect(prop_class).primary_key) + if len(prop_pk) == 1: + prop_pk = prop_pk[0] + + if prop.uselist: + # relationship is x_to_many, value is list + subschema = self[attr].children[0] + value = [subschema.objectify( + obj, + get_context(obj, + session, + prop_class, + prop_pk + ) + ) for obj in dict_[attr]] else: - value = dict_[attr] - if value is colander.null: - # `colander.null` is never an appropriate - # value to be placed on an SQLAlchemy object - # so we translate it into `None`. - value = None - setattr(context, attr, value) + # relationship is x_to_one, value is not a list + + subschema = self[attr] + obj = dict_[attr] + value = subschema.objectify( + obj, + get_context(obj, + session, + prop_class, + prop_pk) + ) + + elif attr in mapper.columns.keys(): + # handle column + value = dict_[attr] + if value is colander.null: + value = None + else: # Ignore attributes if they are not mapped log.debug( @@ -712,8 +756,11 @@ def objectify(self, dict_, context=None): attr, self ) continue - + + # persist value + setattr(context, attr, value) return context + def clone(self): cloned = self.__class__(self.class_, diff --git a/tests/__init__.py b/tests/__init__.py index e2b34a9..2ff3b68 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -5,6 +5,8 @@ # This module is part of ColanderAlchemy and is released under # the MIT License: http://www.opensource.org/licenses/mit-license.php -import tests.test_schema as test_schema +#import tests.test_schema as test_schema, tests.test_persist_data as test_persist_data -__all__ = ['test_schema'] +from tests import test_schema, test_persist_data + +__all__ = ['test_schema', 'test_persist_data'] diff --git a/tests/models.py b/tests/models.py index c64433e..1f89803 100644 --- a/tests/models.py +++ b/tests/models.py @@ -17,6 +17,7 @@ Float, ForeignKey, Integer, + String, Numeric, Time, Unicode, @@ -171,3 +172,4 @@ class Baz(Base): foo_id = Column(Integer, ForeignKey('foos.id')) foo = relationship('Foo', backref='bazs') + \ No newline at end of file diff --git a/tests/test_persist_data.py b/tests/test_persist_data.py new file mode 100644 index 0000000..bb41736 --- /dev/null +++ b/tests/test_persist_data.py @@ -0,0 +1,109 @@ +import colanderalchemy + +from sqlalchemy import ( + event, + create_engine, + Column, + ForeignKey, + Integer, + String, +) +from sqlalchemy.ext.declarative import declarative_base +from sqlalchemy.orm import (sessionmaker, relationship) +from sqlalchemy.engine import Engine + +import logging +import sys + +if sys.version_info[0] == 2 and sys.version_info[1] < 7: + # In Python < 2.7 use unittest2. + import unittest2 as unittest +else: + import unittest + +logging.basicConfig(level=logging.WARNING) +log = logging.getLogger(__name__) + + +@event.listens_for(Engine, "connect") +def set_sqlite_pragma(dbapi_connection, connection_record): + """Define referential integrity """ + cursor = dbapi_connection.cursor() + cursor.execute("PRAGMA foreign_keys=ON") + cursor.close() + log = logging.getLogger(__name__) + log.info("PRAGMA foreign_keys=ON") + + +Base = declarative_base() + +class A(Base): + __tablename__='a' + id = Column(Integer, primary_key=True) + v = Column(String()) + + id_b = Column(Integer, ForeignKey('b.id')) + bvalue = relationship('B') + + cvalues = relationship('C', + secondary='acassociations', + back_populates='avalues') + +class B(Base): + __tablename__ = 'b' + id = Column(Integer, primary_key=True) + v = String() + +class C(Base): + __tablename__ = 'c' + id = Column(Integer, primary_key=True) + v = Column(String()) + + avalues = relationship('A', + secondary='acassociations', + back_populates='cvalues') + +class ACAssociation(Base): + __tablename__ = 'acassociations' + id_a = Column(Integer, + ForeignKey('a.id', + ondelete='CASCADE', + onupdate='CASCADE'), + primary_key=True) + + id_c = Column(Integer, + ForeignKey('c.id', + ondelete='CASCADE', + onupdate='CASCADE'), + primary_key=True) + + +class Tests_persist_relation(unittest.TestCase): + + def setUp(self): + engine = create_engine('sqlite:///:memory:', echo=True) + Session = sessionmaker(bind=engine) + self.session = Session() + Base.metadata.create_all(engine) + + def tearDown(self): + self.session.close() + + def test_persist_relation(self): + # create an object + a_1 = A(v='a_1', bvalue=B(v='b'), cvalues=[C(v='c_1'), C(v='c_2')]) + #a_1 = A(v='a_1', cvalues=[C(v='c_1'), C(v='c_2')]) + self.session.add(a_1) + self.session.commit() + + # create a SQLAlchemySchemaNode + schema = colanderalchemy.SQLAlchemySchemaNode(A) + + # get data from a_1 + appstruct = schema.dictify(a_1) + + # objectify appstruct to a_1 + schema.objectify(appstruct, a_1) + + # should not fail + self.session.commit()