Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 2 additions & 3 deletions hech-interp.cabal
Original file line number Diff line number Diff line change
Expand Up @@ -42,11 +42,10 @@ common base
, dependent-map
, constraints-extras
, dependent-sum
, dependent-sum-template
, constraints
, tiktoken
, megaparsec < 9.7
, vector-sized
, HList == 0.5.4.0


common binary-base
Expand All @@ -60,7 +59,7 @@ library
exposed-modules:
GPT2
GPT2.Loader
GPT2.Model
GPT2.CachedModel
SafeTensors
hs-source-dirs: src
ghc-options:
Expand Down
4 changes: 2 additions & 2 deletions src/GPT2.hs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
module GPT2 (module GPT2.Loader, module GPT2.Model) where
module GPT2 (module GPT2.Loader, module GPT2.CachedModel) where

import GPT2.Loader
import GPT2.Model
import GPT2.CachedModel
575 changes: 337 additions & 238 deletions src/GPT2/CachedModel.hs

Large diffs are not rendered by default.

243 changes: 127 additions & 116 deletions src/GPT2/HListExtensions.hs
Original file line number Diff line number Diff line change
@@ -1,137 +1,148 @@
{-# LANGUAGE AllowAmbiguousTypes #-}
{- Enriching Torch.HList -}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE FunctionalDependencies #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE InstanceSigs #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE ParallelListComp #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE Strict #-}
{-# LANGUAGE StarIsType #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}
{-# LANGUAGE NoStarIsType #-}
{-# OPTIONS_GHC -fconstraint-solver-iterations=0 #-}
{-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-}

module GPT2.HListExtensions
( HList,
HReplicateR,
HAppendListR,
HAppendFD,
ApplyAB,
applyAB,
HScanl,
hTail,
Nat2HNat,
HNat2Nat,
hLast,
HReplicate,
hScanl,
hScanlTail,
hReplicate,
hReplicateF
)
where

import Data.HList hiding (ApplyAB, applyAB)
import GHC.TypeLits
import Prelude hiding (cos, exp, sin)

-- | TODO: Originally, ApplyAB doesn't have this fundep.
class ApplyAB f a b | f a -> b where
applyAB :: f -> a -> b

-- class HLast (xs :: [Type]) (y :: Type) | xs -> y where
-- hLast :: HList xs -> y

-- instance HLast '[x] x where
-- hLast :: HList '[x] -> x
-- hLast (x `HCons` HNil) = x

-- instance (HLast (b ': xs) y) => HLast (a ': b ': xs) y where
-- hLast (_`HCons` xs) = hLast xs

-- class HScanr f z ls rs where
-- -- | Correspond to `scanr :: (a -> b -> b) -> b -> [a] -> [b]`
-- hScanr :: f -> z -> HList ls -> HList rs

-- instance (lz ~ '[z]) => HScanr f z '[] lz where
-- hScanr _ z _ = HCons (z, HNil)

-- instance
-- ( ApplyAB f (x, r) s,
-- HScanr f z xs (r ': rs),
-- srrs ~ (s ': r ': rs)
-- ) =>
-- HScanr f z (x ': xs) srrs
-- where
-- hScanr f z (HCons (x, xs)) =
-- case hScanr f z xs :: HList (r ': rs) of
-- HCons (r, rs) -> HCons (applyAB f (x, r) :: s, HCons (r, rs))

class HScanl f z ls rs where
-- | Correspond to `scanl :: (b -> a -> b) -> b -> [a] -> [b]`
hScanl :: f -> z -> HList ls -> HList rs

scanr' :: (a -> b -> b) -> b -> [a] -> [b]
scanr' _ ini [] = [ini]
scanr' f acc (x : xs) = f x y : ys
where
ys@(y : _) = scanr' f acc xs

scanl' :: ((a, b) -> b) -> b -> [a] -> [b]
scanl' f acc xs = acc : [f (x, y) | x <- xs | y <- scanl' f acc xs]
{-# LANGUAGE PatternSynonyms #-}

instance (lz ~ '[z]) => HScanl f z '[] lz where
hScanl _ z _ = z `HCons` HNil
module GPT2.HListExtensions where

instance
( ApplyAB f (z, x) s,
HScanl f s xs rs
import Data.Kind (Type)
import GHC.TypeLits
import Unsafe.Coerce (unsafeCoerce)
import GHC.Exts (IsList (..))

data family HList (l :: [Type])

data instance HList '[] = HNil

data instance HList (x ': xs) = x `HCons` HList xs

pattern (:.) :: forall x (xs :: [Type]). x -> HList xs -> HList (x : xs)
pattern (:.) x xs = HCons x xs

type family HReplicateR (n :: Nat) (e :: Type) :: [Type] where
HReplicateR 0 _ = '[]
HReplicateR n e = e ': HReplicateR (n - 1) e

-- Ideally, we can write `hmap` as following:
-- However, because `HReplicateR` is non-injective, we don't necessarily know that
-- `HNil ~ HList (HReplicateR n a) => n = 0`.
-- Consequently, given `foo :: HList (HReplicateR n a)`, you simply can't
-- pattern match it. And hence, you can't write a `hmap`.
--
-- Ideally:
-- hmap :: (a -> b) -> HList (HReplicateR n a) -> HList (HReplicateR n b)
-- hmap _ _ = _
-- hmap f (HCons x xs) = HCons (f x) (hmap f xs)

-- So we will explicitly establish connections between type variables.
-- We will not let the compiler infer structure from `HReplicateR`,
-- turning
-- `HMapC` says: Given elements types `a` and `b`, and list structures `as` and `bs`,
-- we define a way to transform an `HList as` into an `HList bs` using a function `a -> b`.
class HMapC a b as bs | a b as -> bs, a b bs -> as where
hmapC :: (a -> b) -> HList as -> HList bs

-- An empty list maps to an empty list, regardless of the element types
instance HMapC a b '[] '[] where
hmapC _ HNil = HNil

-- If we know how to map a list of as to bs,
-- then we can map a list with an a at the front to a list with a b at the front
instance (HMapC a b as bs) => HMapC a b (a ': as) (b ': bs) where
hmapC f (HCons x xs) = HCons (f x) (hmapC f xs)

mapFst ::
forall a b n.
(HMapC (a, b) a (HReplicateR n (a, b)) (HReplicateR n a)) =>
HList (HReplicateR n (a, b)) ->
HList (HReplicateR n a)
mapFst = hmapC @(a, b) fst

l1 :: HList (HReplicateR 2 (Int, String))
l1 = HCons (10, "asdf") (HCons (20, "zcv") HNil)

l2 :: HList (HReplicateR 2 Int)
l2 = mapFst @Int @_ @2 l1

class HLast (xs :: [Type]) (y :: Type) | xs -> y where
hLast :: HList xs -> y

instance HLast '[x] x where
hLast (x `HCons` HNil) = x

instance (HLast (b ': xs) y) => HLast (a ': b ': xs) y where
hLast (_ `HCons` xs) = hLast xs

class HTail (xs :: [Type]) (ys :: [Type]) | xs -> ys where
hTail :: HList xs -> HList ys

instance HTail (a ': as) as where
hTail (_ `HCons` xs) = xs

-- HScanlC applies a binary operation cumulatively over a list
-- The type signature mirrors scanl: (b -> a -> b) -> b -> [a] -> [b]
-- But for heterogeneous lists, we need to track the evolving accumulator type
class HScanlC a b as bs | a b as -> bs where
hScanlC :: (b -> a -> b) -> b -> HList as -> HList bs

-- Base case: scanning an empty list produces a singleton list containing just the initial value
instance HScanlC a b '[] '[b] where
hScanlC _ b HNil = HCons b HNil

instance (HScanlC a b as bs) => HScanlC a b (a ': as) (b ': bs) where
hScanlC :: (b -> a -> b) -> b -> HList (a : as) -> HList (b : bs)
hScanlC f acc (HCons x xs) =
let newAcc = f acc x
in HCons acc (hScanlC f newAcc xs)

-- HScanlC
-- Integer Integer (HReplicateR n Int) (HReplicateR (1 + n) Int)

l3 = hScanlC @Int (+) 0 l2

ok = hTail l3

foo ::
forall n.
( HScanlC
Integer
Integer
(HReplicateR n Int)
(HReplicateR (1 + n) Int)
) =>
HScanl f z (x ': xs) (z ': rs)
where
hScanl f z (x `HCons` xs) =
z `HCons` hScanl f (applyAB f (z, x) :: s) xs
HList (HReplicateR n Int) ->
HList (HReplicateR (n + 1) Int)
foo = hScanlC (+) 0

class HScanlTail f z ls rs where
-- | Correspond to `scanl :: (b -> a -> b) -> b -> [a] -> [b]`
hScanlTail :: f -> z -> HList ls -> HList rs

instance (lz ~ '[]) => HScanlTail f z '[] lz where
hScanlTail _ _ _ = HNil
instance IsList (Maybe (HList '[(a :: Type)])) where
type Item (Maybe (HList '[(a :: Type)])) = a
fromList [x] = liftA2 (:.) (Just x) (Just HNil)
fromList _ = Nothing
toList Nothing = []
toList (Just (x :. HNil)) = [x]

instance
( ApplyAB f (z, x) s,
HScanlTail f s xs rs
( IsList (Maybe (HList (a ': as))),
a ~ Item (Maybe (HList (a ': as)))
) =>
HScanlTail f z (x ': xs) (z ': rs)
IsList (Maybe (HList ((a :: Type) ': a ': as)))
where
hScanlTail f z (x `HCons` xs) =
z `HCons` hScanlTail f (applyAB f (z, x) :: s) xs

data AddFunc = AddFunc

instance ApplyAB AddFunc (String, Int) String where
applyAB _ (x, y) = show x ++ show y

type family Nat2HNat (n :: Nat) :: HNat where
Nat2HNat 0 = HZero
Nat2HNat n = HSucc (Nat2HNat (n - 1))

example :: IO ()
example = do
-- Create an input list: [1, 2, 3]
let inputList = (1 :: Int) .*. (2 :: Int) .*. HNil

-- Apply hScanr with AddFunc, starting value 0
result :: HList (HReplicateR (Nat2HNat 3) String)
result = hScanl AddFunc "asdf" inputList

ok = hTail result
-- Result should be equivalent to scanr (+) 0 [1,2,3] = [6,5,3,0]
print result
type Item (Maybe (HList (a ': a ': as))) = a
fromList (x : xs) = liftA2 (:.) (Just x) (fromList xs)
fromList _ = Nothing
toList Nothing = []
toList (Just (x :. xs)) = x : toList (Just xs)
11 changes: 8 additions & 3 deletions src/GPT2/Loader.hs
Original file line number Diff line number Diff line change
Expand Up @@ -8,17 +8,17 @@ import Control.Monad.IO.Class (liftIO)
import Control.Monad.Trans.Maybe
import GHC.Exts (IsList (fromList))
import GHC.Utils.Monad (zipWith3M)
import GPT2.Model
import GPT2.CachedModel
import GPT2.HListExtensions
import SafeTensors
import qualified Torch as UT
import qualified Torch.DType as D
import Torch.HList
import Torch.Typed hiding
( MultiheadAttention,
TransformerLayer,
TransformerMLP,
keys,
transformerLM,
transformerLM, HList, HReplicateR
)
import Prelude hiding (lookup)

Expand All @@ -30,6 +30,8 @@ type NumAttnLayers = 12

type NumHeads = 12

type HeadDim = 64

type FFNDim = 3072

type PaddingIdx = 0
Expand All @@ -42,6 +44,7 @@ type Model =
GPT2
NumAttnLayers
NumHeads
HeadDim
FFNDim
PaddingIdx
MaxSeqLen
Expand All @@ -54,6 +57,7 @@ type ModelSpec =
GPT2Spec
NumAttnLayers
NumHeads
HeadDim
FFNDim
PaddingIdx
MaxSeqLen
Expand Down Expand Up @@ -213,6 +217,7 @@ loadGPT2FromSafeTensors ::
( GPT2
NumAttnLayers
NumHeads
HeadDim
FFNDim
PaddingIdx
MaxSeqLen
Expand Down
Loading