diff --git a/src/Python/Inline/Literal.hs b/src/Python/Inline/Literal.hs index 76fed0e..8ec221b 100644 --- a/src/Python/Inline/Literal.hs +++ b/src/Python/Inline/Literal.hs @@ -20,6 +20,7 @@ import Data.Char import Data.Int import Data.Word import Data.Set qualified as Set +import Data.Map.Strict qualified as Map import Foreign.Ptr import Foreign.C.Types import Foreign.Storable @@ -408,6 +409,33 @@ instance (FromPy a, Ord a) => FromPy (Set.Set a) where pure $! Set.insert a s) Set.empty + +instance (ToPy k, ToPy v, Ord k) => ToPy (Map.Map k v) where + basicToPy dct = runProgram $ do + p_dict <- takeOwnership =<< checkNull basicNewDict + progPy $ do + let loop [] = p_dict <$ incref p_dict + loop ((k,v):xs) = basicToPy k >>= \case + NULL -> mustThrowPyError + p_k -> flip finally (decref p_k) $ basicToPy v >>= \case + NULL -> mustThrowPyError + p_v -> Py [CU.exp| int { PyDict_SetItem($(PyObject *p_dict), $(PyObject* p_k), $(PyObject *p_v)) }|] >>= \case + 0 -> loop xs + _ -> nullPtr <$ decref p_v + loop $ Map.toList dct + +instance (FromPy k, FromPy v, Ord k) => FromPy (Map.Map k v) where + basicFromPy p_dct = basicGetIter p_dct >>= \case + NULL -> do Py [C.exp| void { PyErr_Clear() } |] + throwM BadPyType + p_iter -> foldPyIterable p_iter + (\m p -> do k <- basicFromPy p + v <- Py [CU.exp| PyObject* { PyDict_GetItem($(PyObject* p_dct), $(PyObject *p)) }|] >>= \case + NULL -> throwM BadPyType + p_v -> basicFromPy p_v + pure $! Map.insert k v m) + Map.empty + -- | Fold over iterable. Function takes ownership over iterator. foldPyIterable :: Ptr PyObject -- ^ Python iterator (not checked) diff --git a/test/TST/Roundtrip.hs b/test/TST/Roundtrip.hs index c7cfd4f..2d337bd 100644 --- a/test/TST/Roundtrip.hs +++ b/test/TST/Roundtrip.hs @@ -6,6 +6,7 @@ import Data.Int import Data.Word import Data.Typeable import Data.Set (Set) +import Data.Map.Strict (Map) import Foreign.C.Types import Test.Tasty @@ -53,6 +54,7 @@ tests = testGroup "Roundtrip" , testRoundtrip @[Int] , testRoundtrip @[[Int]] , testRoundtrip @(Set Int) + , testRoundtrip @(Map Int Int) -- , testRoundtrip @String -- Trips on zero byte as it should ] , testGroup "OutOfRange" diff --git a/test/TST/ToPy.hs b/test/TST/ToPy.hs index a712fb7..12b48d1 100644 --- a/test/TST/ToPy.hs +++ b/test/TST/ToPy.hs @@ -2,6 +2,7 @@ module TST.ToPy (tests) where import Data.Set qualified as Set +import Data.Map.Strict qualified as Map import Test.Tasty import Test.Tasty.HUnit import Python.Inline @@ -38,5 +39,11 @@ tests = testGroup "ToPy" in [py_| assert x_hs == {1,3,5} |] , testCase "set unhashable" $ runPy $ let x = Set.fromList [[1], [5], [3::Int]] - in throwsPy [py_| assert x_hs == {1,3,5} |] + in throwsPy [py_| x_hs |] + , testCase "dict" $ runPy $ + let x = Map.fromList [(1,10), (5,50), (3,30)] :: Map.Map Int Int + in [py_| assert x_hs == {1:10, 3:30, 5:50} |] + , testCase "dict unhashable" $ runPy $ + let x = Map.fromList [([1],10), ([5],50), ([3],30)] :: Map.Map [Int] Int + in throwsPy [py_| x_hs |] ]