From 5e208746f2d4a9bea1c2169024eef1f86f0a9c5d Mon Sep 17 00:00:00 2001 From: Andrew Chen Date: Mon, 14 Apr 2025 22:57:45 -0400 Subject: [PATCH 1/3] Oh snap I'm almost there --- src/GPT2/CachedModel.hs | 441 ++++++++++++++++++++++-------------- src/GPT2/HListExtensions.hs | 230 +++++++++---------- src/GPT2/Loader.hs | 7 +- src/infer.hs | 13 +- 4 files changed, 397 insertions(+), 294 deletions(-) diff --git a/src/GPT2/CachedModel.hs b/src/GPT2/CachedModel.hs index 9a36e3d..19769b4 100644 --- a/src/GPT2/CachedModel.hs +++ b/src/GPT2/CachedModel.hs @@ -591,7 +591,7 @@ data -- | positional embedding tPosEmbedding :: Embedding 'Nothing maxSeqLen dmodel 'Constant dtype device, -- | transformer layers - tLayers :: HList (HReplicateR (Nat2HNat numAttnLayers) (TransformerLayer dmodel nhead ffnDim dtype device)), + tLayers :: HList (HReplicateR (numAttnLayers) (TransformerLayer dmodel nhead ffnDim dtype device)), -- | final layer norm tFinalLN :: LayerNorm '[dmodel] dtype device, -- | final output projection @@ -604,7 +604,7 @@ deriving instance ( Show ( HList ( HReplicateR - (Nat2HNat numAttnLayers) + (numAttnLayers) ( TransformerLayer dmodel nhead @@ -617,38 +617,53 @@ deriving instance ) => Show (GPT2 numAttnLayers nhead dhead ffnDim paddingIdx maxSeqLen vocabSize dmodel dtype device) -instance - ( layers - ~ HReplicateR - (Nat2HNat numAttnLayers) - ( TransformerLayer - dmodel - nhead - ffnDim - dtype - device - ), - Parameterized - ( HList - layers - ), - HAppendFD - (Parameters (HList layers)) - '[ Parameter device dtype '[dmodel], - Parameter device dtype '[dmodel], - Parameter device dtype '[vocabSize, dmodel], - Parameter device dtype '[vocabSize] - ] - ( HAppendListR - (Parameters (HList layers)) - '[ Parameter device dtype '[dmodel], - Parameter device dtype '[dmodel], - Parameter device dtype '[vocabSize, dmodel], - Parameter device dtype '[vocabSize] - ] - ) - ) => - Parameterized (GPT2 numAttnLayers nhead dhead ffnDim paddingIdx maxSeqLen vocabSize dmodel dtype device) +-- TODO: I don't get how a Parametrized instance is derived. +-- instance +-- ( layers +-- ~ HReplicateR +-- ( numAttnLayers) +-- ( TransformerLayer +-- dmodel +-- nhead +-- ffnDim +-- dtype +-- device +-- ), +-- Parameterized +-- ( HList +-- layers +-- ), +-- HAppendFD +-- (Parameters (HList layers)) +-- '[ Parameter device dtype '[dmodel], +-- Parameter device dtype '[dmodel], +-- Parameter device dtype '[vocabSize, dmodel], +-- Parameter device dtype '[vocabSize] +-- ] +-- ( HAppendListR +-- (Parameters (HList layers)) +-- '[ Parameter device dtype '[dmodel], +-- Parameter device dtype '[dmodel], +-- Parameter device dtype '[vocabSize, dmodel], +-- Parameter device dtype '[vocabSize] +-- ] +-- ), +-- Tensor.HAppendFD +-- (Parameters (HList layers)) +-- [ Parameter device dtype '[dmodel], +-- Parameter device dtype '[dmodel], +-- Parameter device dtype [vocabSize, dmodel], +-- Parameter device dtype '[vocabSize] +-- ] +-- ( Parameters (HList layers) +-- Tensor.++ [ Parameter device dtype '[dmodel], +-- Parameter device dtype '[dmodel], +-- Parameter device dtype [vocabSize, dmodel], +-- Parameter device dtype '[vocabSize] +-- ] +-- ) +-- ) => +-- Parameterized (GPT2 numAttnLayers nhead dhead ffnDim paddingIdx maxSeqLen vocabSize dmodel dtype device) newtype FoldLayers @@ -661,40 +676,6 @@ newtype flAttentionMask :: Maybe (Tensor device dtype '[batchSize, seqLen, seqLen]) } -instance - ( 1 <= nhead, - dmodel ~ (dhead * nhead), - All KnownNat '[dmodel, nhead, seqLen, batchSize, dhead], - IsSuffixOf '[dmodel] '[batchSize, seqLen, dmodel], - KnownDType dtype, - StandardFloatingPointDTypeValidation device dtype, - MatMulDTypeIsValid device dtype, - BasicArithmeticDTypeIsValid device dtype, - GeluDTypeIsValid device dtype, - dtype ~ SumDType dtype, - SumDTypeIsValid device dtype, - KnownDevice device, - MeanDTypeValidation device dtype, - AllDimsPositive '[batchSize, seqLen, dmodel] - ) => - ApplyAB - (FoldLayers batchSize seqLen dtype device) - ( ( Tensor device dtype '[batchSize, seqLen, dmodel], - DMap - (BlockCache device dtype batchSize seqLen dmodel dhead nhead) - Identity - ), - TransformerLayer dmodel nhead ffnDim dtype device - ) - ( Tensor device dtype '[batchSize, seqLen, dmodel], - DMap - (BlockCache device dtype batchSize seqLen dmodel dhead nhead) - Identity - ) - where - applyAB FoldLayers {..} ((x, _), layer) = do - let (res, cache) = transformerLayer layer flAttentionMask x in (res, cache) - transformerLM :: forall numAttnLayers @@ -709,47 +690,102 @@ transformerLM :: batchSize dtype device. - ( All KnownNat '[vocabSize, paddingIdx, dmodel, seqLen, batchSize], + ( All KnownNat '[vocabSize, paddingIdx, dmodel, seqLen, batchSize, numAttnLayers, nhead, dhead], + StandardFloatingPointDTypeValidation + device + dtype, + GeluDTypeIsValid device dtype, + MatMulDTypeIsValid device dtype, + SumDTypeIsValid device dtype, IsSuffixOf '[dmodel] '[batchSize, seqLen, dmodel], paddingIdx + 1 <= vocabSize, 1 <= seqLen, + 1 <= nhead, BasicArithmeticDTypeIsValid device dtype, ComparisonDTypeIsValid device dtype, ComparisonDTypeIsValid device 'D.Int64, - -- HLast - -- ( HReplicateR - -- numAttnLayers - -- ( Tensor device dtype [batchSize, seqLen, dmodel], - -- DMap - -- (BlockCache device dtype batchSize seqLen dmodel dhead nhead) - -- Identity - -- ) - -- ) - -- ( Tensor device dtype [batchSize, seqLen, dmodel], - -- DMap - -- (BlockCache device dtype batchSize seqLen dmodel dhead nhead) - -- Identity - -- ), - -- HScanlTail - -- (FoldLayers batchSize seqLen dtype device) - -- ( Tensor device dtype '[batchSize, seqLen, dmodel], - -- DMap - -- (BlockCache device dtype batchSize seqLen dmodel dhead nhead) - -- Identity - -- ) - -- (HReplicateR numAttnLayers (TransformerLayer dmodel nhead ffnDim dtype device)) - -- ( HReplicateR - -- numAttnLayers - -- ( Tensor device dtype '[batchSize, seqLen, dmodel], - -- DMap - -- (BlockCache device dtype batchSize seqLen dmodel dhead nhead) - -- Identity - -- ) - -- ), KnownDType dtype, KnownDevice device, MeanDTypeValidation device dtype, - AllDimsPositive '[batchSize, seqLen, dmodel] + AllDimsPositive + '[batchSize, seqLen, dmodel], + SumDType dtype ~ dtype, + (dhead * nhead) ~ dmodel, + HScanlC + (TransformerLayer dmodel nhead ffnDim dtype device) -- cur + ( Tensor device dtype [batchSize, seqLen, dmodel], -- acc + DMap + (BlockCache device dtype batchSize seqLen dmodel dhead nhead) + Identity + ) + ( HReplicateR -- as + numAttnLayers + (TransformerLayer dmodel nhead ffnDim dtype device) + ) + ( HReplicateR -- bs + (1 + numAttnLayers) + ( Tensor device dtype [batchSize, seqLen, dmodel], + DMap + (BlockCache device dtype batchSize seqLen dmodel dhead nhead) + Identity + ) + ), + HTail + ( HReplicateR + (1 + numAttnLayers) + ( Tensor device dtype [batchSize, seqLen, dmodel], + DMap + (BlockCache device dtype batchSize seqLen dmodel dhead nhead) + Identity + ) + ) + ( HReplicateR + numAttnLayers + ( Tensor device dtype [batchSize, seqLen, dmodel], + DMap + (BlockCache device dtype batchSize seqLen dmodel dhead nhead) + Identity + ) + ), + HMapC + ( Tensor device dtype [batchSize, seqLen, dmodel], + DMap + (BlockCache device dtype batchSize seqLen dmodel dhead nhead) + Identity + ) + ( DMap + (BlockCache device dtype batchSize seqLen dmodel dhead nhead) + Identity + ) + ( HReplicateR + numAttnLayers + ( Tensor device dtype [batchSize, seqLen, dmodel], + DMap + (BlockCache device dtype batchSize seqLen dmodel dhead nhead) + Identity + ) + ) + ( HReplicateR + numAttnLayers + ( DMap + (BlockCache device dtype batchSize seqLen dmodel dhead nhead) + Identity + ) + ), + HLast + ( HReplicateR + numAttnLayers + ( Tensor device dtype [batchSize, seqLen, dmodel], + DMap + (BlockCache device dtype batchSize seqLen dmodel dhead nhead) + Identity + ) + ) + ( Tensor device dtype [batchSize, seqLen, dmodel], + DMap + (BlockCache device dtype batchSize seqLen dmodel dhead nhead) + Identity + ) ) => GPT2 numAttnLayers nhead dhead ffnDim paddingIdx maxSeqLen vocabSize dmodel dtype device -> Tensor device 'D.Int64 '[batchSize, seqLen] -> @@ -778,64 +814,136 @@ transformerLM GPT2 {..} xTokens = do let emp :: DMap (BlockCache device dtype batchSize seqLen dmodel dhead nhead) Identity emp = empty - intermediateLayerOutputs :: + allOutputs :: HList ( HReplicateR - (Nat2HNat numAttnLayers) - ( Tensor device dtype '[batchSize, seqLen, dmodel], + (numAttnLayers + 1) + ( Tensor device dtype [batchSize, seqLen, dmodel], DMap (BlockCache device dtype batchSize seqLen dmodel dhead nhead) Identity ) ) - intermediateLayerOutputs = hScanlTail (FoldLayers (Just attentionMask')) (x', emp) tLayers - (finalOut, _) = hLast intermediateLayerOutputs + allOutputs = + hScanlC @(TransformerLayer dmodel nhead ffnDim dtype device) + (\(b, _) a -> transformerLayer a (Just attentionMask') b) + (x', emp) + tLayers + intermediateOutputs = hTail allOutputs + intermediateCaches = + hmapC + @( Tensor device dtype '[batchSize, seqLen, dmodel], + DMap + (BlockCache device dtype batchSize seqLen dmodel dhead nhead) + Identity + ) + @( DMap + (BlockCache device dtype batchSize seqLen dmodel dhead nhead) + Identity + ) + snd + intermediateOutputs + (finalOut, _) = hLast intermediateOutputs finalLnCache = layerNormForwardCached tFinalLN finalOut finalDist = forward tProj $ runIdentity (finalLnCache ! Normalized) - (finalDist, fromList [Embed ==> x, PosEmbed ==> x', LnFinal ==> finalLnCache]) + (finalDist, fromList [Embed ==> x, PosEmbed ==> x', LnFinal ==> finalLnCache, Blocks ==> intermediateCaches]) instance - ( All KnownNat '[vocabSize, paddingIdx, dmodel, seqLen, batchSize, seqLen], + ( All KnownNat '[vocabSize, paddingIdx, dmodel, seqLen, batchSize, seqLen, numAttnLayers, nhead, dhead], + StandardFloatingPointDTypeValidation + device + dtype, + GeluDTypeIsValid device dtype, + MatMulDTypeIsValid device dtype, + SumDTypeIsValid device dtype, IsSuffixOf '[dmodel] '[batchSize, seqLen, dmodel], paddingIdx + 1 <= vocabSize, + SumDType dtype ~ dtype, + (dhead * nhead) ~ dmodel, + 1 <= nhead, 1 <= seqLen, - -- HLast - -- ( HReplicateR - -- numAttnLayers - -- ( Tensor device dtype [batchSize, seqLen, dmodel], - -- DMap - -- (BlockCache device dtype batchSize seqLen dmodel dhead nhead) - -- Identity - -- ) - -- ) - -- ( Tensor device dtype [batchSize, seqLen, dmodel], - -- DMap - -- (BlockCache device dtype batchSize seqLen dmodel dhead nhead) - -- Identity - -- ), - -- HScanlTail - -- (FoldLayers batchSize seqLen dtype device) - -- ( Tensor device dtype '[batchSize, seqLen, dmodel], - -- DMap - -- (BlockCache device dtype batchSize seqLen dmodel dhead nhead) - -- Identity - -- ) - -- (HReplicateR numAttnLayers (TransformerLayer dmodel nhead ffnDim dtype device)) - -- ( HReplicateR - -- numAttnLayers - -- ( Tensor device dtype '[batchSize, seqLen, dmodel], - -- DMap - -- (BlockCache device dtype batchSize seqLen dmodel dhead nhead) - -- Identity - -- ) - -- ), BasicArithmeticDTypeIsValid device dtype, ComparisonDTypeIsValid device dtype, ComparisonDTypeIsValid device 'D.Int64, KnownDType dtype, KnownDevice device, MeanDTypeValidation device dtype, - AllDimsPositive '[batchSize, seqLen, dmodel] + AllDimsPositive '[batchSize, seqLen, dmodel], + HScanlC + (TransformerLayer dmodel nhead ffnDim dtype device) -- cur + ( Tensor device dtype [batchSize, seqLen, dmodel], -- acc + DMap + (BlockCache device dtype batchSize seqLen dmodel dhead nhead) + Identity + ) + ( HReplicateR -- as + numAttnLayers + (TransformerLayer dmodel nhead ffnDim dtype device) + ) + ( HReplicateR -- bs + (1 + numAttnLayers) + ( Tensor device dtype [batchSize, seqLen, dmodel], + DMap + (BlockCache device dtype batchSize seqLen dmodel dhead nhead) + Identity + ) + ), + HTail + ( HReplicateR + (1 + numAttnLayers) + ( Tensor device dtype [batchSize, seqLen, dmodel], + DMap + (BlockCache device dtype batchSize seqLen dmodel dhead nhead) + Identity + ) + ) + ( HReplicateR + numAttnLayers + ( Tensor device dtype [batchSize, seqLen, dmodel], + DMap + (BlockCache device dtype batchSize seqLen dmodel dhead nhead) + Identity + ) + ), + HMapC + ( Tensor device dtype [batchSize, seqLen, dmodel], + DMap + (BlockCache device dtype batchSize seqLen dmodel dhead nhead) + Identity + ) + ( DMap + (BlockCache device dtype batchSize seqLen dmodel dhead nhead) + Identity + ) + ( HReplicateR + numAttnLayers + ( Tensor device dtype [batchSize, seqLen, dmodel], + DMap + (BlockCache device dtype batchSize seqLen dmodel dhead nhead) + Identity + ) + ) + ( HReplicateR + numAttnLayers + ( DMap + (BlockCache device dtype batchSize seqLen dmodel dhead nhead) + Identity + ) + ), + HLast + ( HReplicateR + numAttnLayers + ( Tensor device dtype [batchSize, seqLen, dmodel], + DMap + (BlockCache device dtype batchSize seqLen dmodel dhead nhead) + Identity + ) + ) + ( Tensor device dtype [batchSize, seqLen, dmodel], + DMap + (BlockCache device dtype batchSize seqLen dmodel dhead nhead) + Identity + ) ) => HasForward (GPT2 numAttnLayers nhead dhead ffnDim paddingIdx seqLen vocabSize dmodel dtype device) @@ -878,32 +986,31 @@ sinusoidal = weights = stack @2 (sin radians Tensor.:. cos radians Tensor.:. Tensor.HNil) in reshape weights -instance - ( paddingIdx <= vocabSize, - 1 <= maxSeqLen, - 1 <= vocabSize - paddingIdx, - 1 <= Div dmodel 2, - (((vocabSize - paddingIdx) - 1) + (1 + paddingIdx)) ~ vocabSize, - (Div dmodel 2 * 2) ~ dmodel, - All KnownNat '[ffnDim, paddingIdx, vocabSize, maxSeqLen, dmodel], - HReplicate (Nat2HNat numAttnLayers) (TransformerLayerSpec dmodel nhead ffnDim dtype device), - A.Randomizable - (HList (HReplicateR (Nat2HNat numAttnLayers) (TransformerLayerSpec dmodel nhead ffnDim dtype device))) - (HList (HReplicateR (Nat2HNat numAttnLayers) (TransformerLayer dmodel nhead ffnDim dtype device))), - KnownDType dtype, - RandDTypeIsValid device dtype, - StandardFloatingPointDTypeValidation device 'D.Float, - BasicArithmeticDTypeIsValid device 'D.Float, - KnownDevice device - ) => - A.Randomizable - (GPT2Spec numAttnLayers nhead dhead ffnDim paddingIdx maxSeqLen vocabSize dmodel dtype device) - (GPT2 numAttnLayers nhead dhead ffnDim paddingIdx maxSeqLen vocabSize dmodel dtype device) - where - sample GPT2Spec {..} = - GPT2 - <$> A.sample (LearnedEmbeddingWithRandomInitSpec @('Just paddingIdx)) - <*> A.sample (ConstEmbeddingSpec @'Nothing (Torch.Typed.Tensor.toDType sinusoidal)) - <*> A.sample (hReplicate @(Nat2HNat numAttnLayers) lmLayerSpec) - <*> A.sample (LayerNormSpec epsSpec'') - <*> A.sample LinearSpec \ No newline at end of file +-- instance +-- ( paddingIdx <= vocabSize, +-- 1 <= maxSeqLen, +-- 1 <= vocabSize - paddingIdx, +-- 1 <= Div dmodel 2, +-- (((vocabSize - paddingIdx) - 1) + (1 + paddingIdx)) ~ vocabSize, +-- (Div dmodel 2 * 2) ~ dmodel, +-- All KnownNat '[ffnDim, paddingIdx, vocabSize, maxSeqLen, dmodel], +-- A.Randomizable +-- (HList (HReplicateR ( numAttnLayers) (TransformerLayerSpec dmodel nhead ffnDim dtype device))) +-- (HList (HReplicateR ( numAttnLayers) (TransformerLayer dmodel nhead ffnDim dtype device))), +-- KnownDType dtype, +-- RandDTypeIsValid device dtype, +-- StandardFloatingPointDTypeValidation device 'D.Float, +-- BasicArithmeticDTypeIsValid device 'D.Float, +-- KnownDevice device +-- ) => +-- A.Randomizable +-- (GPT2Spec numAttnLayers nhead dhead ffnDim paddingIdx maxSeqLen vocabSize dmodel dtype device) +-- (GPT2 numAttnLayers nhead dhead ffnDim paddingIdx maxSeqLen vocabSize dmodel dtype device) +-- where +-- sample GPT2Spec {..} = +-- GPT2 +-- <$> A.sample (LearnedEmbeddingWithRandomInitSpec @('Just paddingIdx)) +-- <*> A.sample (ConstEmbeddingSpec @'Nothing (Torch.Typed.Tensor.toDType sinusoidal)) +-- <*> A.sample (hReplicate (Proxy @( numAttnLayers)) lmLayerSpec) +-- <*> A.sample (LayerNormSpec epsSpec'') +-- <*> A.sample LinearSpec \ No newline at end of file diff --git a/src/GPT2/HListExtensions.hs b/src/GPT2/HListExtensions.hs index fe18a58..4a286a2 100644 --- a/src/GPT2/HListExtensions.hs +++ b/src/GPT2/HListExtensions.hs @@ -1,137 +1,123 @@ {-# LANGUAGE AllowAmbiguousTypes #-} -{- Enriching Torch.HList -} {-# LANGUAGE DataKinds #-} {-# LANGUAGE FlexibleContexts #-} {-# LANGUAGE FlexibleInstances #-} {-# LANGUAGE FunctionalDependencies #-} {-# LANGUAGE GADTs #-} {-# LANGUAGE InstanceSigs #-} +{-# LANGUAGE MultiParamTypeClasses #-} {-# LANGUAGE ParallelListComp #-} -{-# LANGUAGE PolyKinds #-} +{-# LANGUAGE RankNTypes #-} {-# LANGUAGE ScopedTypeVariables #-} -{-# LANGUAGE Strict #-} +{-# LANGUAGE StarIsType #-} +{-# LANGUAGE TypeApplications #-} {-# LANGUAGE TypeFamilies #-} {-# LANGUAGE TypeOperators #-} {-# LANGUAGE UndecidableInstances #-} -{-# LANGUAGE NoStarIsType #-} -{-# OPTIONS_GHC -fconstraint-solver-iterations=0 #-} -{-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-} - -module GPT2.HListExtensions - ( HList, - HReplicateR, - HAppendListR, - HAppendFD, - ApplyAB, - applyAB, - HScanl, - hTail, - Nat2HNat, - HNat2Nat, - hLast, - HReplicate, - hScanl, - hScanlTail, - hReplicate, - hReplicateF - ) -where - -import Data.HList hiding (ApplyAB, applyAB) -import GHC.TypeLits -import Prelude hiding (cos, exp, sin) - --- | TODO: Originally, ApplyAB doesn't have this fundep. -class ApplyAB f a b | f a -> b where - applyAB :: f -> a -> b - --- class HLast (xs :: [Type]) (y :: Type) | xs -> y where --- hLast :: HList xs -> y - --- instance HLast '[x] x where --- hLast :: HList '[x] -> x --- hLast (x `HCons` HNil) = x - --- instance (HLast (b ': xs) y) => HLast (a ': b ': xs) y where --- hLast (_`HCons` xs) = hLast xs - --- class HScanr f z ls rs where --- -- | Correspond to `scanr :: (a -> b -> b) -> b -> [a] -> [b]` --- hScanr :: f -> z -> HList ls -> HList rs - --- instance (lz ~ '[z]) => HScanr f z '[] lz where --- hScanr _ z _ = HCons (z, HNil) - --- instance --- ( ApplyAB f (x, r) s, --- HScanr f z xs (r ': rs), --- srrs ~ (s ': r ': rs) --- ) => --- HScanr f z (x ': xs) srrs --- where --- hScanr f z (HCons (x, xs)) = --- case hScanr f z xs :: HList (r ': rs) of --- HCons (r, rs) -> HCons (applyAB f (x, r) :: s, HCons (r, rs)) - -class HScanl f z ls rs where - -- | Correspond to `scanl :: (b -> a -> b) -> b -> [a] -> [b]` - hScanl :: f -> z -> HList ls -> HList rs - -scanr' :: (a -> b -> b) -> b -> [a] -> [b] -scanr' _ ini [] = [ini] -scanr' f acc (x : xs) = f x y : ys - where - ys@(y : _) = scanr' f acc xs - -scanl' :: ((a, b) -> b) -> b -> [a] -> [b] -scanl' f acc xs = acc : [f (x, y) | x <- xs | y <- scanl' f acc xs] - -instance (lz ~ '[z]) => HScanl f z '[] lz where - hScanl _ z _ = z `HCons` HNil - -instance - ( ApplyAB f (z, x) s, - HScanl f s xs rs - ) => - HScanl f z (x ': xs) (z ': rs) - where - hScanl f z (x `HCons` xs) = - z `HCons` hScanl f (applyAB f (z, x) :: s) xs - -class HScanlTail f z ls rs where - -- | Correspond to `scanl :: (b -> a -> b) -> b -> [a] -> [b]` - hScanlTail :: f -> z -> HList ls -> HList rs -instance (lz ~ '[]) => HScanlTail f z '[] lz where - hScanlTail _ _ _ = HNil +module GPT2.HListExtensions where -instance - ( ApplyAB f (z, x) s, - HScanlTail f s xs rs +import Data.Kind (Type) +import GHC.TypeLits +import Unsafe.Coerce (unsafeCoerce) + +data family HList (l :: [Type]) + +data instance HList '[] = HNil + +data instance HList (x ': xs) = x `HCons` HList xs + +type family HReplicateR (n :: Nat) (e :: Type) :: [Type] where + HReplicateR 0 _ = '[] + HReplicateR n e = e ': HReplicateR (n - 1) e + +-- Ideally, we can write `hmap` as following: +-- However, because `HReplicateR` is non-injective, we don't necessarily know that +-- `HNil ~ HList (HReplicateR n a) => n = 0`. +-- Consequently, given `foo :: HList (HReplicateR n a)`, you simply can't +-- pattern match it. And hence, you can't write a `hmap`. +-- +-- Ideally: +-- hmap :: (a -> b) -> HList (HReplicateR n a) -> HList (HReplicateR n b) +-- hmap _ _ = _ +-- hmap f (HCons x xs) = HCons (f x) (hmap f xs) + +-- So we will explicitly establish connections between type variables. +-- We will not let the compiler infer structure from `HReplicateR`, +-- turning +-- `HMapC` says: Given elements types `a` and `b`, and list structures `as` and `bs`, +-- we define a way to transform an `HList as` into an `HList bs` using a function `a -> b`. +class HMapC a b as bs | a b as -> bs, a b bs -> as where + hmapC :: (a -> b) -> HList as -> HList bs + +-- An empty list maps to an empty list, regardless of the element types +instance HMapC a b '[] '[] where + hmapC _ HNil = HNil + +-- If we know how to map a list of as to bs, +-- then we can map a list with an a at the front to a list with a b at the front +instance (HMapC a b as bs) => HMapC a b (a ': as) (b ': bs) where + hmapC f (HCons x xs) = HCons (f x) (hmapC f xs) + +mapFst :: + forall a b n. + (HMapC (a, b) a (HReplicateR n (a, b)) (HReplicateR n a)) => + HList (HReplicateR n (a, b)) -> + HList (HReplicateR n a) +mapFst = hmapC @(a, b) fst + +l1 :: HList (HReplicateR 2 (Int, String)) +l1 = HCons (10, "asdf") (HCons (20, "zcv") HNil) + +l2 :: HList (HReplicateR 2 Int) +l2 = mapFst @Int @_ @2 l1 + +class HLast (xs :: [Type]) (y :: Type) | xs -> y where + hLast :: HList xs -> y + +instance HLast '[x] x where + hLast (x `HCons` HNil) = x + +instance (HLast (b ': xs) y) => HLast (a ': b ': xs) y where + hLast (_ `HCons` xs) = hLast xs + +class HTail (xs :: [Type]) (ys :: [Type]) | xs -> ys where + hTail :: HList xs -> HList ys + +instance HTail (a ': as) as where + hTail (_ `HCons` xs) = xs + +-- HScanlC applies a binary operation cumulatively over a list +-- The type signature mirrors scanl: (b -> a -> b) -> b -> [a] -> [b] +-- But for heterogeneous lists, we need to track the evolving accumulator type +class HScanlC a b as bs | a b as -> bs where + hScanlC :: (b -> a -> b) -> b -> HList as -> HList bs + +-- Base case: scanning an empty list produces a singleton list containing just the initial value +instance HScanlC a b '[] '[b] where + hScanlC _ b HNil = HCons b HNil + +instance (HScanlC a b as bs) => HScanlC a b (a ': as) (b ': bs) where + hScanlC :: (b -> a -> b) -> b -> HList (a : as) -> HList (b : bs) + hScanlC f acc (HCons x xs) = + let newAcc = f acc x + in HCons acc (hScanlC f newAcc xs) + +-- HScanlC +-- Integer Integer (HReplicateR n Int) (HReplicateR (1 + n) Int) + +l3 = hScanlC @Int (+) 0 l2 + +ok = hTail l3 + +foo :: + forall n. + ( HScanlC + Integer + Integer + (HReplicateR n Int) + (HReplicateR (1 + n) Int) ) => - HScanlTail f z (x ': xs) (z ': rs) - where - hScanlTail f z (x `HCons` xs) = - z `HCons` hScanlTail f (applyAB f (z, x) :: s) xs - -data AddFunc = AddFunc - -instance ApplyAB AddFunc (String, Int) String where - applyAB _ (x, y) = show x ++ show y - -type family Nat2HNat (n :: Nat) :: HNat where - Nat2HNat 0 = HZero - Nat2HNat n = HSucc (Nat2HNat (n - 1)) - -example :: IO () -example = do - -- Create an input list: [1, 2, 3] - let inputList = (1 :: Int) .*. (2 :: Int) .*. HNil - - -- Apply hScanr with AddFunc, starting value 0 - result :: HList (HReplicateR (Nat2HNat 3) String) - result = hScanl AddFunc "asdf" inputList - - ok = hTail result - -- Result should be equivalent to scanr (+) 0 [1,2,3] = [6,5,3,0] - print result + HList (HReplicateR n Int) -> + HList (HReplicateR (n + 1) Int) +foo = hScanlC (+) 0 diff --git a/src/GPT2/Loader.hs b/src/GPT2/Loader.hs index 5ab025f..172559c 100644 --- a/src/GPT2/Loader.hs +++ b/src/GPT2/Loader.hs @@ -8,7 +8,7 @@ import Control.Monad.IO.Class (liftIO) import Control.Monad.Trans.Maybe import GHC.Exts (IsList (fromList)) import GHC.Utils.Monad (zipWith3M) -import GPT2.Model +import GPT2.CachedModel import SafeTensors import qualified Torch as UT import qualified Torch.DType as D @@ -30,6 +30,8 @@ type NumAttnLayers = 12 type NumHeads = 12 +type HeadDim = 64 + type FFNDim = 3072 type PaddingIdx = 0 @@ -42,6 +44,7 @@ type Model = GPT2 NumAttnLayers NumHeads + HeadDim FFNDim PaddingIdx MaxSeqLen @@ -54,6 +57,7 @@ type ModelSpec = GPT2Spec NumAttnLayers NumHeads + HeadDim FFNDim PaddingIdx MaxSeqLen @@ -213,6 +217,7 @@ loadGPT2FromSafeTensors :: ( GPT2 NumAttnLayers NumHeads + HeadDim FFNDim PaddingIdx MaxSeqLen diff --git a/src/infer.hs b/src/infer.hs index 954123a..456c56d 100644 --- a/src/infer.hs +++ b/src/infer.hs @@ -9,6 +9,7 @@ {-# LANGUAGE MultiParamTypeClasses #-} {-# LANGUAGE PolyKinds #-} {-# LANGUAGE RankNTypes #-} +{-# LANGUAGE RecordWildCards #-} {-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE TypeApplications #-} {-# LANGUAGE TypeFamilies #-} @@ -23,16 +24,20 @@ import Control.Monad.Trans.Class (lift) import Control.Monad.Trans.Maybe import Data.ByteString.Char8 (pack) import Data.Constraint +import Data.Dependent.Map +import Data.Functor.Identity import Data.Proxy import GHC.Int (Int64) import GHC.TypeLits +import GPT2.CachedModel +import GPT2.HListExtensions import GPT2.Loader -import GPT2.Model (transformerLM) import SafeTensors hiding (shape) import System.Environment (getArgs) import Tiktoken (r50k_base, toRanks) import qualified Torch as UT import qualified Torch.DType as D +import qualified Torch.Device as D import Torch.Internal.Cast (cast2) import qualified Torch.Internal.Managed.Native as ATen.Managed import Torch.Typed hiding (length, sample, transformerLM) @@ -90,7 +95,7 @@ infer :: Dict ((1 <=? numTokens) ~ 'True) -> Model -> [[Int64]] -> - Tensor ModelDevice UT.Float '[1, numTokens, VocabSize] + (Tensor ModelDevice UT.Float '[1, numTokens, VocabSize], GPT2ActivationCache ModelDevice 1 numTokens) infer Dict model tokens = transformerLM model $ UnsafeMkTensor @@ -100,7 +105,7 @@ infer Dict model tokens = $ UT.asTensor tokens -- | returns the next token prediction, already sampled --- [1, 1] indicates a batch size of 1 and a sequence length of 1 +-- [1, 1] indicates a batch size of 1 and a sequence length of 1 runInference :: FilePath -> MaybeT IO (Tensor ModelDevice 'Int64 [1, 1]) runInference [] = hoistMaybe Nothing runInference (fp :: FilePath) = do @@ -110,7 +115,7 @@ runInference (fp :: FilePath) = do withNat (length tokens) $ \(proxy :: Proxy numTokens) -> do dict <- hoistMaybe $ mkNumTokensProof @numTokens proxy - let result = infer @numTokens dict model [map fromIntegral tokens] + let (result, cache) = infer @numTokens dict model [Prelude.map fromIntegral tokens] lift $ sample result -- | Example: `cabal run -- /Users/jane.doe/model.safetensors`, must be absolute! From ccb7b007ac610601b770e2811ce4ec89369fc7fd Mon Sep 17 00:00:00 2001 From: Andrew Chen Date: Tue, 15 Apr 2025 00:24:12 -0400 Subject: [PATCH 2/3] Implemented transformer caching --- hech-interp.cabal | 4 +- src/GPT2.hs | 4 +- src/GPT2/CachedModel.hs | 182 ++++++----- src/GPT2/HListExtensions.hs | 25 ++ src/GPT2/Loader.hs | 4 +- src/GPT2/Model.hs | 602 ------------------------------------ src/infer.hs | 6 +- 7 files changed, 120 insertions(+), 707 deletions(-) delete mode 100644 src/GPT2/Model.hs diff --git a/hech-interp.cabal b/hech-interp.cabal index a450cd7..e1654f5 100644 --- a/hech-interp.cabal +++ b/hech-interp.cabal @@ -42,10 +42,10 @@ common base , dependent-map , constraints-extras , dependent-sum + , dependent-sum-template , constraints , tiktoken , megaparsec < 9.7 - , vector-sized , HList == 0.5.4.0 @@ -60,7 +60,7 @@ library exposed-modules: GPT2 GPT2.Loader - GPT2.Model + GPT2.CachedModel SafeTensors hs-source-dirs: src ghc-options: diff --git a/src/GPT2.hs b/src/GPT2.hs index 40fe235..e6b9f2f 100644 --- a/src/GPT2.hs +++ b/src/GPT2.hs @@ -1,4 +1,4 @@ -module GPT2 (module GPT2.Loader, module GPT2.Model) where +module GPT2 (module GPT2.Loader, module GPT2.CachedModel) where import GPT2.Loader -import GPT2.Model +import GPT2.CachedModel diff --git a/src/GPT2/CachedModel.hs b/src/GPT2/CachedModel.hs index 19769b4..a5f40aa 100644 --- a/src/GPT2/CachedModel.hs +++ b/src/GPT2/CachedModel.hs @@ -16,6 +16,7 @@ {-# LANGUAGE TypeOperators #-} {-# LANGUAGE UndecidableInstances #-} {-# LANGUAGE NoStarIsType #-} +{-# LANGUAGE TemplateHaskell #-} {-# OPTIONS_GHC -fconstraint-solver-iterations=0 #-} {-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-} @@ -27,7 +28,6 @@ import Data.Dependent.Sum ((==>)) import Data.Functor.Identity (Identity (..)) import Data.GADT.Compare import Data.Proxy -import qualified Data.Vector.Sized as V import Debug.Trace import GHC.Generics import GHC.TypeLits @@ -51,6 +51,8 @@ import Torch.Typed.NN.Sparse import Torch.Typed.Parameter import Torch.Typed.Tensor import Prelude hiding (cos, exp, sin) +import Data.GADT.Compare.TH (deriveGCompare, deriveGEq) +import Data.GADT.Show.TH (deriveGShow) -- TODO not sure if we actually need this -- import Data.GADT.Compare.TH (deriveGCompare, deriveGEq) @@ -60,51 +62,22 @@ residual f g x = f x >>= (\x' -> g (x `add` x')) traceTensor ten = trace (show . T.sliceDim 0 0 5 1 . T.select 0 0 . T.squeezeAll $ toDynamic ten) ten -data - ActivationCache - (device :: (D.DeviceType, Nat)) - (dtype :: D.DType) - (batchSize :: Nat) - (seqLen :: Nat) - (dmodel :: Nat) - (dhead :: Nat) - (nhead :: Nat) - (nlayers :: Nat) - a - where - Embed :: ActivationCache device dtype batchSize seqLen dmodel dhead nhead nlayers (Tensor device dtype '[batchSize, seqLen, dmodel]) - PosEmbed :: ActivationCache device dtype batchSize seqLen dmodel dhead nhead nlayers (Tensor device dtype '[batchSize, seqLen, dmodel]) - Blocks :: ActivationCache device dtype batchSize seqLen dmodel dhead nhead nlayers (HList (HReplicateR nLayers (DMap (BlockCache device dtype batchSize seqLen dmodel dhead nhead) Identity))) - LnFinal :: ActivationCache device dtype batchSize seqLen dmodel dhead nhead nlayers (DMap (LayerNormCache device dtype batchSize seqLen dmodel) Identity) - -deriving instance GEq (ActivationCache device dtype batchSize seqLen dmodel dhead nhead nlayers) - -deriving instance GCompare (ActivationCache device dtype batchSize seqLen dmodel dhead nhead nlayers) data - BlockCache + LayerNormCache (device :: (D.DeviceType, Nat)) (dtype :: D.DType) (batchSize :: Nat) (seqLen :: Nat) (dmodel :: Nat) - (dhead :: Nat) - (nhead :: Nat) a where - ResidPre :: BlockCache device dtype batchSize seqLen dmodel dhead nhead (Tensor device dtype '[batchSize, seqLen, dmodel]) - Ln1 :: BlockCache device dtype batchSize seqLen dmodel dhead nhead (DMap (LayerNormCache device dtype batchSize seqLen dmodel) Identity) - Attn :: BlockCache device dtype batchSize seqLen dmodel dhead nhead (DMap (AttentionCache device dtype batchSize seqLen dmodel dhead nhead) Identity) - AttnOut :: BlockCache device dtype batchSize seqLen dmodel dhead nhead (Tensor device dtype '[batchSize, seqLen, dmodel]) - ResidMid :: BlockCache device dtype batchSize seqLen dmodel dhead nhead (Tensor device dtype '[batchSize, seqLen, dmodel]) - Ln2 :: BlockCache device dtype batchSize seqLen dmodel dhead nhead (DMap (LayerNormCache device dtype batchSize seqLen dmodel) Identity) - MLP :: BlockCache device dtype batchSize seqLen dmodel dhead nhead (DMap (MLPCache device dtype batchSize seqLen dmodel) Identity) - MLPOut :: BlockCache device dtype batchSize seqLen dmodel dhead nhead (Tensor device dtype '[batchSize, seqLen, dmodel]) - ResidPost :: BlockCache device dtype batchSize seqLen dmodel dhead nhead (Tensor device dtype '[batchSize, seqLen, dmodel]) - -deriving instance GEq (BlockCache device dtype batchSize seqLen dmodel dhead nhead) + -- | Scale is the std of the input (residual stream) + Scale :: LayerNormCache device dtype batchSize seqLen dmodel (Tensor device dtype '[batchSize, seqLen, 1]) + Normalized :: LayerNormCache device dtype batchSize seqLen dmodel (Tensor device dtype '[batchSize, seqLen, dmodel]) -deriving instance GCompare (BlockCache device dtype batchSize seqLen dmodel dhead nhead) +deriveGEq ''LayerNormCache +deriveGCompare ''LayerNormCache data AttentionCache @@ -125,56 +98,74 @@ data Z :: AttentionCache device dtype batchSize seqLen dmodel dhead nhead (Tensor device dtype '[batchSize, seqLen, nhead, dhead]) Result :: AttentionCache device dtype batchSize seqLen dmodel dhead nhead (Tensor device dtype '[batchSize, seqLen, nhead, dmodel]) -deriving instance GEq (AttentionCache device dtype batchSize seqLen dmodel dhead nhead) +deriveGEq ''AttentionCache +deriveGCompare ''AttentionCache + +data + MLPCache + (device :: (D.DeviceType, Nat)) + (dtype :: D.DType) + (batchSize :: Nat) + (seqLen :: Nat) + (ffnDim :: Nat) + a + where + -- Right before activation function + MLPpre :: MLPCache device dtype batchSize seqLen ffnDim (Tensor device dtype '[batchSize, seqLen, ffnDim]) + -- Right after activation function + MLPpost :: MLPCache device dtype batchSize seqLen ffnDim (Tensor device dtype '[batchSize, seqLen, ffnDim]) -deriving instance GCompare (AttentionCache device dtype batchSize seqLen dmodel dhead nhead) +deriveGEq ''MLPCache +deriveGCompare ''MLPCache data - LayerNormCache + BlockCache (device :: (D.DeviceType, Nat)) (dtype :: D.DType) (batchSize :: Nat) (seqLen :: Nat) (dmodel :: Nat) + (dhead :: Nat) + (nhead :: Nat) + (ffnDim :: Nat) a where - -- | Scale is the std of the input (residual stream) - Scale :: LayerNormCache device dtype batchSize seqLen dmodel (Tensor device dtype '[batchSize, seqLen, 1]) - Normalized :: LayerNormCache device dtype batchSize seqLen dmodel (Tensor device dtype '[batchSize, seqLen, dmodel]) + ResidPre :: BlockCache device dtype batchSize seqLen dmodel dhead nhead ffnDim(Tensor device dtype '[batchSize, seqLen, dmodel]) + Ln1 :: BlockCache device dtype batchSize seqLen dmodel dhead nhead ffnDim (DMap (LayerNormCache device dtype batchSize seqLen dmodel) Identity) + Attn :: BlockCache device dtype batchSize seqLen dmodel dhead nhead ffnDim (DMap (AttentionCache device dtype batchSize seqLen dmodel dhead nhead) Identity) + AttnOut :: BlockCache device dtype batchSize seqLen dmodel dhead nhead ffnDim (Tensor device dtype '[batchSize, seqLen, dmodel]) + ResidMid :: BlockCache device dtype batchSize seqLen dmodel dhead nhead ffnDim (Tensor device dtype '[batchSize, seqLen, dmodel]) + Ln2 :: BlockCache device dtype batchSize seqLen dmodel dhead nhead ffnDim (DMap (LayerNormCache device dtype batchSize seqLen dmodel) Identity) + MLP :: BlockCache device dtype batchSize seqLen dmodel dhead nhead ffnDim (DMap (MLPCache device dtype batchSize seqLen ffnDim) Identity) + MLPOut :: BlockCache device dtype batchSize seqLen dmodel dhead nhead ffnDim (Tensor device dtype '[batchSize, seqLen, dmodel]) + ResidPost :: BlockCache device dtype batchSize seqLen dmodel dhead nhead ffnDim (Tensor device dtype '[batchSize, seqLen, dmodel]) -deriving instance GEq (LayerNormCache device dtype batchSize seqLen dmodel) +deriveGEq ''BlockCache +deriveGCompare ''BlockCache -deriving instance GCompare (LayerNormCache device dtype batchSize seqLen dmodel) data - MLPCache + ActivationCache (device :: (D.DeviceType, Nat)) (dtype :: D.DType) (batchSize :: Nat) (seqLen :: Nat) + (dmodel :: Nat) + (dhead :: Nat) + (nhead :: Nat) + (nlayers :: Nat) (ffnDim :: Nat) a where - -- Right before activation function - MLPpre :: MLPCache device dtype batchSize seqLen dmodel (Tensor device dtype '[batchSize, seqLen, ffnDim]) - -- Right after activation function - MLPpost :: MLPCache device dtype batchSize seqLen dmodel (Tensor device dtype '[batchSize, seqLen, ffnDim]) - -deriving instance GEq (MLPCache device dtype batchSize seqLen dmodel) + Embed :: ActivationCache device dtype batchSize seqLen dmodel dhead nhead nlayers ffnDim (Tensor device dtype '[batchSize, seqLen, dmodel]) + PosEmbed :: ActivationCache device dtype batchSize seqLen dmodel dhead nhead nlayers ffnDim (Tensor device dtype '[batchSize, seqLen, dmodel]) + Blocks :: ActivationCache device dtype batchSize seqLen dmodel dhead nhead nlayers ffnDim (HList (HReplicateR nlayers (DMap (BlockCache device dtype batchSize seqLen dmodel dhead nhead ffnDim) Identity))) + LnFinal :: ActivationCache device dtype batchSize seqLen dmodel dhead nhead nlayers ffnDim (DMap (LayerNormCache device dtype batchSize seqLen dmodel) Identity) -deriving instance GCompare (MLPCache device dtype batchSize seqLen dmodel) +deriveGEq ''ActivationCache +deriveGCompare ''ActivationCache --- TODO not sure if we actually need this it's in the example --- deriveGEq ''CacheTag --- deriveGCompare ''CacheTag --- deriveGShow ''CacheTag --- deriveArgDict ''CacheTag - --- type GPT2BlockCache device batchSize seqLen = DMap (BlockCache device D.Float batchSize seqLen 768 64 12) Identity --- type GPT2AttentionCache device batchSize seqLen = DMap (AttentionCache device D.Float batchSize seqLen 768 64 12) Identity --- type GPT2LayerNormCache device batchSize seqLen = DMap (LayerNormCache device D.Float batchSize seqLen 768) Identity --- type GPT2MLPCache device batchSize seqLen = DMap (MLPCache device D.Float batchSize seqLen 768) Identity -type GPT2ActivationCache device batchSize seqLen = DMap (ActivationCache device D.Float batchSize seqLen 768 64 12 12) Identity +type GPT2ActivationCache device batchSize seqLen = DMap (ActivationCache device D.Float batchSize seqLen 768 64 12 12 (768 * 4)) Identity geluApproximate :: forall shape dtype device. @@ -352,10 +343,11 @@ data (batchSize :: Nat) (seqLen :: Nat) (dmodel :: Nat) + (ffnDim :: Nat) = TransformerMLPActivation { -- | ln2.normalized mlpLn :: DMap (LayerNormCache device dtype batchSize seqLen dmodel) Identity, - mlpCache :: DMap (MLPCache device dtype batchSize seqLen dmodel) Identity, + mlpCache :: DMap (MLPCache device dtype batchSize seqLen ffnDim) Identity, mlpOut :: Tensor device dtype '[batchSize, seqLen, dmodel], -- | resid_post mlpResult :: Tensor device dtype '[batchSize, seqLen, dmodel] @@ -379,7 +371,7 @@ transformerMLP :: -- | MLP model ADT for transformer TransformerMLP dmodel ffnDim dtype device -> Tensor device dtype '[batchSize, maxSeqLen, dmodel] -> -- input - TransformerMLPActivation device dtype batchSize maxSeqLen dmodel -- output + TransformerMLPActivation device dtype batchSize maxSeqLen dmodel ffnDim -- output transformerMLP TransformerMLP {..} x = TransformerMLPActivation {mlpLn = mlpLn, mlpCache = mlpCache, mlpOut = linear1Out, mlpResult = residPost} where mlpLn = layerNormForwardCached ln x @@ -488,7 +480,7 @@ transformerLayer :: forall (nhead :: Nat) (ffnDim :: Nat) (dmodel :: Nat) (dhead :: Nat) (seqLen :: Nat) (batchSize :: Nat) dtype device. ( 1 <= nhead, dmodel ~ (dhead * nhead), - All KnownNat '[dmodel, dmodel, dmodel, nhead, seqLen, batchSize, dhead], + All KnownNat '[dmodel, dmodel, dmodel, nhead, seqLen, batchSize, dhead, ffnDim], IsSuffixOf '[dmodel] '[batchSize, seqLen, dmodel], KnownDType dtype, dtype ~ SumDType dtype, @@ -508,7 +500,7 @@ transformerLayer :: -- | input Tensor device dtype '[batchSize, seqLen, dmodel] -> -- | transformer layer output representation - (Tensor device dtype '[batchSize, seqLen, dmodel], DMap (BlockCache device dtype batchSize seqLen dmodel dhead nhead) Identity) + (Tensor device dtype '[batchSize, seqLen, dmodel], DMap (BlockCache device dtype batchSize seqLen dmodel dhead nhead ffnDim) Identity) transformerLayer TransformerLayer {..} attentionMask inp = ( residPost, fromList @@ -690,7 +682,7 @@ transformerLM :: batchSize dtype device. - ( All KnownNat '[vocabSize, paddingIdx, dmodel, seqLen, batchSize, numAttnLayers, nhead, dhead], + ( All KnownNat '[vocabSize, paddingIdx, dmodel, seqLen, batchSize, numAttnLayers, nhead, dhead, ffnDim], StandardFloatingPointDTypeValidation device dtype, @@ -715,7 +707,7 @@ transformerLM :: (TransformerLayer dmodel nhead ffnDim dtype device) -- cur ( Tensor device dtype [batchSize, seqLen, dmodel], -- acc DMap - (BlockCache device dtype batchSize seqLen dmodel dhead nhead) + (BlockCache device dtype batchSize seqLen dmodel dhead nhead ffnDim) Identity ) ( HReplicateR -- as @@ -726,7 +718,7 @@ transformerLM :: (1 + numAttnLayers) ( Tensor device dtype [batchSize, seqLen, dmodel], DMap - (BlockCache device dtype batchSize seqLen dmodel dhead nhead) + (BlockCache device dtype batchSize seqLen dmodel dhead nhead ffnDim) Identity ) ), @@ -735,7 +727,7 @@ transformerLM :: (1 + numAttnLayers) ( Tensor device dtype [batchSize, seqLen, dmodel], DMap - (BlockCache device dtype batchSize seqLen dmodel dhead nhead) + (BlockCache device dtype batchSize seqLen dmodel dhead nhead ffnDim) Identity ) ) @@ -743,32 +735,32 @@ transformerLM :: numAttnLayers ( Tensor device dtype [batchSize, seqLen, dmodel], DMap - (BlockCache device dtype batchSize seqLen dmodel dhead nhead) + (BlockCache device dtype batchSize seqLen dmodel dhead nhead ffnDim) Identity ) ), HMapC ( Tensor device dtype [batchSize, seqLen, dmodel], DMap - (BlockCache device dtype batchSize seqLen dmodel dhead nhead) + (BlockCache device dtype batchSize seqLen dmodel dhead nhead ffnDim) Identity ) ( DMap - (BlockCache device dtype batchSize seqLen dmodel dhead nhead) + (BlockCache device dtype batchSize seqLen dmodel dhead nhead ffnDim) Identity ) ( HReplicateR numAttnLayers ( Tensor device dtype [batchSize, seqLen, dmodel], DMap - (BlockCache device dtype batchSize seqLen dmodel dhead nhead) + (BlockCache device dtype batchSize seqLen dmodel dhead nhead ffnDim) Identity ) ) ( HReplicateR numAttnLayers ( DMap - (BlockCache device dtype batchSize seqLen dmodel dhead nhead) + (BlockCache device dtype batchSize seqLen dmodel dhead nhead ffnDim) Identity ) ), @@ -777,13 +769,13 @@ transformerLM :: numAttnLayers ( Tensor device dtype [batchSize, seqLen, dmodel], DMap - (BlockCache device dtype batchSize seqLen dmodel dhead nhead) + (BlockCache device dtype batchSize seqLen dmodel dhead nhead ffnDim) Identity ) ) ( Tensor device dtype [batchSize, seqLen, dmodel], DMap - (BlockCache device dtype batchSize seqLen dmodel dhead nhead) + (BlockCache device dtype batchSize seqLen dmodel dhead nhead ffnDim) Identity ) ) => @@ -791,7 +783,7 @@ transformerLM :: Tensor device 'D.Int64 '[batchSize, seqLen] -> ( Tensor device dtype '[batchSize, seqLen, vocabSize], DMap - (ActivationCache device dtype batchSize seqLen dmodel dhead nhead numAttnLayers) + (ActivationCache device dtype batchSize seqLen dmodel dhead nhead numAttnLayers ffnDim) Identity ) transformerLM GPT2 {..} xTokens = do @@ -812,7 +804,7 @@ transformerLM GPT2 {..} xTokens = do maskedFill attentionMask (-(1 / 0) :: Double) $ zeros @'[batchSize, seqLen, seqLen] @dtype @device - let emp :: DMap (BlockCache device dtype batchSize seqLen dmodel dhead nhead) Identity + let emp :: DMap (BlockCache device dtype batchSize seqLen dmodel dhead nhead ffnDim) Identity emp = empty allOutputs :: HList @@ -820,7 +812,7 @@ transformerLM GPT2 {..} xTokens = do (numAttnLayers + 1) ( Tensor device dtype [batchSize, seqLen, dmodel], DMap - (BlockCache device dtype batchSize seqLen dmodel dhead nhead) + (BlockCache device dtype batchSize seqLen dmodel dhead nhead ffnDim) Identity ) ) @@ -834,11 +826,11 @@ transformerLM GPT2 {..} xTokens = do hmapC @( Tensor device dtype '[batchSize, seqLen, dmodel], DMap - (BlockCache device dtype batchSize seqLen dmodel dhead nhead) + (BlockCache device dtype batchSize seqLen dmodel dhead nhead ffnDim) Identity ) @( DMap - (BlockCache device dtype batchSize seqLen dmodel dhead nhead) + (BlockCache device dtype batchSize seqLen dmodel dhead nhead ffnDim) Identity ) snd @@ -849,7 +841,7 @@ transformerLM GPT2 {..} xTokens = do (finalDist, fromList [Embed ==> x, PosEmbed ==> x', LnFinal ==> finalLnCache, Blocks ==> intermediateCaches]) instance - ( All KnownNat '[vocabSize, paddingIdx, dmodel, seqLen, batchSize, seqLen, numAttnLayers, nhead, dhead], + ( All KnownNat '[vocabSize, paddingIdx, dmodel, seqLen, batchSize, seqLen, numAttnLayers, nhead, dhead, ffnDim], StandardFloatingPointDTypeValidation device dtype, @@ -873,7 +865,7 @@ instance (TransformerLayer dmodel nhead ffnDim dtype device) -- cur ( Tensor device dtype [batchSize, seqLen, dmodel], -- acc DMap - (BlockCache device dtype batchSize seqLen dmodel dhead nhead) + (BlockCache device dtype batchSize seqLen dmodel dhead nhead ffnDim) Identity ) ( HReplicateR -- as @@ -884,7 +876,7 @@ instance (1 + numAttnLayers) ( Tensor device dtype [batchSize, seqLen, dmodel], DMap - (BlockCache device dtype batchSize seqLen dmodel dhead nhead) + (BlockCache device dtype batchSize seqLen dmodel dhead nhead ffnDim) Identity ) ), @@ -893,7 +885,7 @@ instance (1 + numAttnLayers) ( Tensor device dtype [batchSize, seqLen, dmodel], DMap - (BlockCache device dtype batchSize seqLen dmodel dhead nhead) + (BlockCache device dtype batchSize seqLen dmodel dhead nhead ffnDim) Identity ) ) @@ -901,32 +893,32 @@ instance numAttnLayers ( Tensor device dtype [batchSize, seqLen, dmodel], DMap - (BlockCache device dtype batchSize seqLen dmodel dhead nhead) + (BlockCache device dtype batchSize seqLen dmodel dhead nhead ffnDim) Identity ) ), HMapC ( Tensor device dtype [batchSize, seqLen, dmodel], DMap - (BlockCache device dtype batchSize seqLen dmodel dhead nhead) + (BlockCache device dtype batchSize seqLen dmodel dhead nhead ffnDim) Identity ) ( DMap - (BlockCache device dtype batchSize seqLen dmodel dhead nhead) + (BlockCache device dtype batchSize seqLen dmodel dhead nhead ffnDim) Identity ) ( HReplicateR numAttnLayers ( Tensor device dtype [batchSize, seqLen, dmodel], DMap - (BlockCache device dtype batchSize seqLen dmodel dhead nhead) + (BlockCache device dtype batchSize seqLen dmodel dhead nhead ffnDim) Identity ) ) ( HReplicateR numAttnLayers ( DMap - (BlockCache device dtype batchSize seqLen dmodel dhead nhead) + (BlockCache device dtype batchSize seqLen dmodel dhead nhead ffnDim) Identity ) ), @@ -935,13 +927,13 @@ instance numAttnLayers ( Tensor device dtype [batchSize, seqLen, dmodel], DMap - (BlockCache device dtype batchSize seqLen dmodel dhead nhead) + (BlockCache device dtype batchSize seqLen dmodel dhead nhead ffnDim) Identity ) ) ( Tensor device dtype [batchSize, seqLen, dmodel], DMap - (BlockCache device dtype batchSize seqLen dmodel dhead nhead) + (BlockCache device dtype batchSize seqLen dmodel dhead nhead ffnDim) Identity ) ) => @@ -950,7 +942,7 @@ instance (Tensor device 'D.Int64 '[batchSize, seqLen]) ( Tensor device dtype '[batchSize, seqLen, vocabSize], DMap - (ActivationCache device dtype batchSize seqLen dmodel dhead nhead numAttnLayers) + (ActivationCache device dtype batchSize seqLen dmodel dhead nhead numAttnLayers ffnDim) Identity ) where diff --git a/src/GPT2/HListExtensions.hs b/src/GPT2/HListExtensions.hs index 4a286a2..365e02e 100644 --- a/src/GPT2/HListExtensions.hs +++ b/src/GPT2/HListExtensions.hs @@ -14,12 +14,14 @@ {-# LANGUAGE TypeFamilies #-} {-# LANGUAGE TypeOperators #-} {-# LANGUAGE UndecidableInstances #-} +{-# LANGUAGE PatternSynonyms #-} module GPT2.HListExtensions where import Data.Kind (Type) import GHC.TypeLits import Unsafe.Coerce (unsafeCoerce) +import GHC.Exts (IsList (..)) data family HList (l :: [Type]) @@ -27,6 +29,9 @@ data instance HList '[] = HNil data instance HList (x ': xs) = x `HCons` HList xs +pattern (:.) :: forall x (xs :: [Type]). x -> HList xs -> HList (x : xs) +pattern (:.) x xs = HCons x xs + type family HReplicateR (n :: Nat) (e :: Type) :: [Type] where HReplicateR 0 _ = '[] HReplicateR n e = e ': HReplicateR (n - 1) e @@ -121,3 +126,23 @@ foo :: HList (HReplicateR n Int) -> HList (HReplicateR (n + 1) Int) foo = hScanlC (+) 0 + + +instance IsList (Maybe (HList '[(a :: Type)])) where + type Item (Maybe (HList '[(a :: Type)])) = a + fromList [x] = liftA2 (:.) (Just x) (Just HNil) + fromList _ = Nothing + toList Nothing = [] + toList (Just (x :. HNil)) = [x] + +instance + ( IsList (Maybe (HList (a ': as))), + a ~ Item (Maybe (HList (a ': as))) + ) => + IsList (Maybe (HList ((a :: Type) ': a ': as))) + where + type Item (Maybe (HList (a ': a ': as))) = a + fromList (x : xs) = liftA2 (:.) (Just x) (fromList xs) + fromList _ = Nothing + toList Nothing = [] + toList (Just (x :. xs)) = x : toList (Just xs) \ No newline at end of file diff --git a/src/GPT2/Loader.hs b/src/GPT2/Loader.hs index 172559c..921567a 100644 --- a/src/GPT2/Loader.hs +++ b/src/GPT2/Loader.hs @@ -9,16 +9,16 @@ import Control.Monad.Trans.Maybe import GHC.Exts (IsList (fromList)) import GHC.Utils.Monad (zipWith3M) import GPT2.CachedModel +import GPT2.HListExtensions import SafeTensors import qualified Torch as UT import qualified Torch.DType as D -import Torch.HList import Torch.Typed hiding ( MultiheadAttention, TransformerLayer, TransformerMLP, keys, - transformerLM, + transformerLM, HList, HReplicateR ) import Prelude hiding (lookup) diff --git a/src/GPT2/Model.hs b/src/GPT2/Model.hs deleted file mode 100644 index 9969bfc..0000000 --- a/src/GPT2/Model.hs +++ /dev/null @@ -1,602 +0,0 @@ -{-# LANGUAGE DataKinds #-} -{-# LANGUAGE DeriveAnyClass #-} -{-# LANGUAGE DeriveGeneric #-} -{-# LANGUAGE FlexibleContexts #-} -{-# LANGUAGE FlexibleInstances #-} -{-# LANGUAGE GADTs #-} -{-# LANGUAGE MultiParamTypeClasses #-} -{-# LANGUAGE RecordWildCards #-} -{-# LANGUAGE ScopedTypeVariables #-} -{-# LANGUAGE StandaloneDeriving #-} -{-# LANGUAGE Strict #-} -{-# LANGUAGE TypeApplications #-} -{-# LANGUAGE TypeFamilies #-} -{-# LANGUAGE TypeOperators #-} -{-# LANGUAGE UndecidableInstances #-} -{-# LANGUAGE NoStarIsType #-} -{-# OPTIONS_GHC -fconstraint-solver-iterations=0 #-} - -module GPT2.Model where - -import Control.Monad -import Data.Proxy -import Debug.Trace -import GHC.Generics -import GHC.TypeLits -import System.IO.Unsafe (unsafePerformIO) -import qualified Torch as T -import qualified Torch.DType as D -import qualified Torch.Device as D -import Torch.HList -import Torch.Internal.Cast (cast2) -import qualified Torch.Internal.Managed.Native as ATen.Managed -import Torch.NN (HasForward (..)) -import qualified Torch.NN as A -import Torch.Typed.Auxiliary -import Torch.Typed.Factories -import Torch.Typed.Functional hiding (linear, log, trace) -import Torch.Typed.NN.Linear -import Torch.Typed.NN.Normalization -import Torch.Typed.NN.Sparse -import Torch.Typed.Parameter -import Torch.Typed.Tensor -import Prelude hiding (cos, exp, sin) - -residual f g x = f x >>= (\x' -> g (x `add` x')) - -traceTensor ten = trace (show . T.sliceDim 0 0 5 1 . T.select 0 0 . T.squeezeAll $ toDynamic ten) ten - -geluApproximate :: - forall shape dtype device. - (GeluDTypeIsValid device dtype) => - Tensor device dtype shape -> - String -> - Tensor device dtype shape -geluApproximate _self _approximate = unsafePerformIO $ cast2 ATen.Managed.gelu_ts _self _approximate - --------------------------------------------------------------------------------- --- Relation-Aware Multi-Headed Attention Layer --------------------------------------------------------------------------------- - -data - MultiheadAttentionSpec - (numEmbeds :: Nat) - (numHeads :: Nat) - (dtype :: D.DType) - (device :: (D.DeviceType, Nat)) - where - MultiheadAttentionSpec :: - -- | spec for dropout - MultiheadAttentionSpec numEmbeds numHeads dtype device - deriving (Show, Eq) - -data - MultiheadAttention - (numEmbeds :: Nat) - (numHeads :: Nat) - (dtype :: D.DType) - (device :: (D.DeviceType, Nat)) - where - MultiheadAttention :: - { -- | packed in-projection for q, k, v - mhaQInProj :: Linear numEmbeds numEmbeds dtype device, - -- | in-projection for key - mhaKInProj :: Linear numEmbeds numEmbeds dtype device, - -- | in-projection for value - mhaVInProj :: Linear numEmbeds numEmbeds dtype device, - -- | out-projection - mhaOutProj :: Linear numEmbeds numEmbeds dtype device - } -> - MultiheadAttention numEmbeds numHeads dtype device - deriving (Show, Generic, Parameterized) - -multiheadAttention :: - forall numEmbeds numHeads inputSeqLen batchSize headDim dtype device. - ( 1 <= numHeads, - numEmbeds ~ (headDim * numHeads), - All KnownNat '[numEmbeds, numEmbeds, numEmbeds, numHeads, inputSeqLen, batchSize, headDim], - KnownDType dtype, - StandardFloatingPointDTypeValidation device dtype, - MatMulDTypeIsValid device dtype, - BasicArithmeticDTypeIsValid device dtype, - dtype ~ SumDType dtype, - SumDTypeIsValid device dtype, - KnownDevice device - ) => - -- | multi-head attention model ADT - MultiheadAttention numEmbeds numHeads dtype device -> - -- | optional attention mask - Maybe (Tensor device dtype '[batchSize, inputSeqLen, inputSeqLen]) -> - -- | query representation - Tensor device dtype '[batchSize, inputSeqLen, numEmbeds] -> - -- | key representation - Tensor device dtype '[batchSize, inputSeqLen, numEmbeds] -> - -- | value representation - Tensor device dtype '[batchSize, inputSeqLen, numEmbeds] -> - -- | attention and attention averaged over heads - IO (Tensor device dtype '[batchSize, inputSeqLen, numEmbeds]) -multiheadAttention MultiheadAttention {..} attentionMask query key value = do - let weights = - softmax @3 - . _maskAttention - $ _attentionWeights - return $ _attention weights - where - _attentionWeights = - let scaling = Prelude.sqrt . fromIntegral $ natValI @headDim :: Double - q = reshape' . forward mhaQInProj $ query - k = reshape' . forward mhaKInProj $ key - weights = divScalar scaling $ matmul q (transpose @2 @3 k) - in weights - _maskAttention attentionWeights = - case attentionMask of - Nothing -> attentionWeights - Just am -> attentionWeights `add` unsqueeze @1 am - _attention attentionWeights = - let v = reshape' . forward mhaVInProj $ value - attention = transpose @1 @2 $ matmul attentionWeights v - in forward mhaOutProj . reshape @'[batchSize, inputSeqLen, numEmbeds] $ attention - reshape' :: - forall inputSeqLen'. - (KnownNat inputSeqLen') => - Tensor device dtype '[batchSize, inputSeqLen', numEmbeds] -> - Tensor device dtype '[batchSize, numHeads, inputSeqLen', headDim] - reshape' t' = transpose @1 @2 $ reshape @'[batchSize, inputSeqLen', numHeads, headDim] t' - -instance - ( All KnownNat '[numEmbeds, numEmbeds, numEmbeds, numHeads], - KnownDType dtype, - KnownDevice device, - RandDTypeIsValid device dtype - ) => - A.Randomizable - (MultiheadAttentionSpec numEmbeds numHeads dtype device) - (MultiheadAttention numEmbeds numHeads dtype device) - where - sample MultiheadAttentionSpec = - MultiheadAttention - <$> A.sample LinearSpec - <*> A.sample LinearSpec - <*> A.sample LinearSpec - <*> A.sample LinearSpec - --------------------------------------------------------------------------------- --- Transformer MLP Layer --------------------------------------------------------------------------------- - -data - TransformerMLPSpec - (numEmbeds :: Nat) - (ffnDim :: Nat) - (dtype :: D.DType) - (device :: (D.DeviceType, Nat)) - where - TransformerMLPSpec :: - forall numEmbeds ffnDim dtype device. - { -- | epsilon for layer norm - epsSpec :: Double - } -> - TransformerMLPSpec numEmbeds ffnDim dtype device - deriving (Show, Eq) - -data - TransformerMLP - (numEmbeds :: Nat) - (ffnDim :: Nat) - (dtype :: D.DType) - (device :: (D.DeviceType, Nat)) - where - TransformerMLP :: - forall numEmbeds ffnDim dtype device. - { -- | first fully connected layer - linear0 :: Linear numEmbeds ffnDim dtype device, - -- | second fully connected layer - linear1 :: Linear ffnDim numEmbeds dtype device, - -- | layer norm - ln :: LayerNorm '[numEmbeds] dtype device - } -> - TransformerMLP numEmbeds ffnDim dtype device - deriving (Show, Generic, Parameterized) - -transformerMLP :: - forall numEmbeds ffnDim maxSeqLen batchSize dtype device. - ( BasicArithmeticDTypeIsValid device dtype, - StandardFloatingPointDTypeValidation device dtype, - KnownNat numEmbeds, - GeluDTypeIsValid device dtype, - IsSuffixOf '[numEmbeds] '[maxSeqLen, batchSize, numEmbeds] - ) => - -- | MLP model ADT for transformer - TransformerMLP numEmbeds ffnDim dtype device -> - Tensor device dtype '[maxSeqLen, batchSize, numEmbeds] -> -- input - IO (Tensor device dtype '[maxSeqLen, batchSize, numEmbeds]) -- output -transformerMLP TransformerMLP {..} x = - return - . (`add` x) - . forward linear1 - . (`geluApproximate` "tanh") - . forward linear0 - $ forward ln x - -instance - ( All KnownNat '[numEmbeds, ffnDim], - KnownDType dtype, - KnownDevice device, - RandDTypeIsValid device dtype - ) => - A.Randomizable - (TransformerMLPSpec numEmbeds ffnDim dtype device) - (TransformerMLP numEmbeds ffnDim dtype device) - where - sample TransformerMLPSpec {..} = - TransformerMLP - <$> A.sample LinearSpec - <*> A.sample LinearSpec - <*> A.sample (LayerNormSpec epsSpec) - --------------------------------------------------------------------------------- --- Relation-Aware Transformer Layer --------------------------------------------------------------------------------- - -data - TransformerLayerSpec - (numEmbeds :: Nat) - (numHeads :: Nat) - (ffnDim :: Nat) - (dtype :: D.DType) - (device :: (D.DeviceType, Nat)) - where - TransformerLayerSpec :: - forall numEmbeds numHeads ffnDim dtype device. - { mhaSpec :: MultiheadAttentionSpec numEmbeds numHeads dtype device, - epsSpec' :: Double, - mlpSpec :: TransformerMLPSpec numEmbeds ffnDim dtype device - } -> - TransformerLayerSpec numEmbeds numHeads ffnDim dtype device - deriving (Show, Eq) - -data - TransformerLayer - (numEmbeds :: Nat) - (numHeads :: Nat) - (ffnDim :: Nat) - (dtype :: D.DType) - (device :: (D.DeviceType, Nat)) - where - TransformerLayer :: - forall numEmbeds numHeads ffnDim dtype device. - { -- | multi-head attention - transformerLayer_mha :: MultiheadAttention numEmbeds numHeads dtype device, - -- | layer norm - transformerLayer_ln :: LayerNorm '[numEmbeds] dtype device, - -- | MLP - transformerLayer_mlp :: TransformerMLP numEmbeds ffnDim dtype device - } -> - TransformerLayer numEmbeds numHeads ffnDim dtype device - deriving (Show, Generic, Parameterized) - -transformerLayer :: - forall (numHeads :: Nat) (ffnDim :: Nat) (numEmbeds :: Nat) (headDim :: Nat) (inputSeqLen :: Nat) (batchSize :: Nat) dtype device. - ( 1 <= numHeads, - numEmbeds ~ (headDim * numHeads), - All KnownNat '[numEmbeds, numEmbeds, numEmbeds, numHeads, inputSeqLen, batchSize, headDim], - IsSuffixOf '[numEmbeds] '[batchSize, inputSeqLen, numEmbeds], - KnownDType dtype, - dtype ~ SumDType dtype, - StandardFloatingPointDTypeValidation device dtype, - GeluDTypeIsValid device dtype, - MatMulDTypeIsValid device dtype, - BasicArithmeticDTypeIsValid device dtype, - SumDTypeIsValid device dtype, - KnownDevice device - ) => - -- | transformer layer model ADT - TransformerLayer numEmbeds numHeads ffnDim dtype device -> - -- | optional attention mask - Maybe (Tensor device dtype '[batchSize, inputSeqLen, inputSeqLen]) -> - -- | query representation - Tensor device dtype '[batchSize, inputSeqLen, numEmbeds] -> - -- | key representation - Tensor device dtype '[batchSize, inputSeqLen, numEmbeds] -> - -- | value representation - Tensor device dtype '[batchSize, inputSeqLen, numEmbeds] -> - -- | transformer layer output representation - IO (Tensor device dtype '[batchSize, inputSeqLen, numEmbeds]) -transformerLayer TransformerLayer {..} attentionMask query key value = - let key' = forward transformerLayer_ln key - value' = forward transformerLayer_ln value - f query' = multiheadAttention transformerLayer_mha attentionMask query' key' value' - in -- _ <- print . T.sliceDim 0 0 5 1 . T.select 0 0 . T.squeezeAll . toDynamic $ fst r - do - result <- (query `add`) <$> f (forward transformerLayer_ln query) - transformerMLP transformerLayer_mlp result - -instance - ( All KnownNat '[numEmbeds, numEmbeds, numEmbeds, numHeads, ffnDim], - KnownDType dtype, - KnownDevice device, - RandDTypeIsValid device dtype - ) => - A.Randomizable - (TransformerLayerSpec numEmbeds numHeads ffnDim dtype device) - (TransformerLayer numEmbeds numHeads ffnDim dtype device) - where - sample TransformerLayerSpec {..} = - TransformerLayer - <$> A.sample mhaSpec - <*> A.sample (LayerNormSpec epsSpec') - <*> A.sample mlpSpec - --------------------------------------------------------------------------------- --- Transformer Language Model (GPT-2) --------------------------------------------------------------------------------- - -data - GPT2Spec - (numAttnLayers :: Nat) - (numHeads :: Nat) - (ffnDim :: Nat) - (paddingIdx :: Nat) - (maxSeqLen :: Nat) - (vocabSize :: Nat) - (numEmbeds :: Nat) - (dtype :: D.DType) - (device :: (D.DeviceType, Nat)) - where - GPT2Spec :: - forall numAttnLayers numHeads ffnDim paddingIdx maxSeqLen vocabSize numEmbeds dtype device. - { -- | spec for each and every transformer layer - lmLayerSpec :: TransformerLayerSpec numEmbeds numHeads ffnDim dtype device, - epsSpec'' :: Double - } -> - GPT2Spec numAttnLayers numHeads ffnDim paddingIdx maxSeqLen vocabSize numEmbeds dtype device - deriving (Show, Eq) - -data - GPT2 - (numAttnLayers :: Nat) - (numHeads :: Nat) - (ffnDim :: Nat) - (paddingIdx :: Nat) - (maxSeqLen :: Nat) - (vocabSize :: Nat) - (numEmbeds :: Nat) - (dtype :: D.DType) - (device :: (D.DeviceType, Nat)) - where - GPT2 :: - forall numAttnLayers numHeads ffnDim paddingIdx maxSeqLen vocabSize numEmbeds dtype device. - { -- | token embedding - tEmbedding :: Embedding ('Just paddingIdx) vocabSize numEmbeds 'Learned dtype device, - -- | positional embedding - tPosEmbedding :: Embedding 'Nothing maxSeqLen numEmbeds 'Constant dtype device, - -- | transformer layers - tLayers :: HList (HReplicateR numAttnLayers (TransformerLayer numEmbeds numHeads ffnDim dtype device)), - -- | final layer norm - tFinalLN :: LayerNorm '[numEmbeds] dtype device, - -- | final output projection - tProj :: Linear numEmbeds vocabSize dtype device - } -> - GPT2 numAttnLayers numHeads ffnDim paddingIdx maxSeqLen vocabSize numEmbeds dtype device - deriving (Generic) - -deriving instance - ( Show - ( HList - ( HReplicateR - numAttnLayers - ( TransformerLayer - numEmbeds - numHeads - ffnDim - dtype - device - ) - ) - ) - ) => - Show (GPT2 numAttnLayers numHeads ffnDim paddingIdx maxSeqLen vocabSize numEmbeds dtype device) - -instance - ( layers - ~ ( HReplicateR - numAttnLayers - ( TransformerLayer - numEmbeds - numHeads - ffnDim - dtype - device - ) - ), - Parameterized - ( HList - layers - ), - HAppendFD - (Parameters (HList layers)) - '[ Parameter device dtype '[numEmbeds], - Parameter device dtype '[numEmbeds], - Parameter device dtype '[vocabSize, numEmbeds], - Parameter device dtype '[vocabSize] - ] - ( Parameters (HList layers) - ++ '[ Parameter device dtype '[numEmbeds], - Parameter device dtype '[numEmbeds], - Parameter device dtype '[vocabSize, numEmbeds], - Parameter device dtype '[vocabSize] - ] - ) - ) => - Parameterized (GPT2 numAttnLayers numHeads ffnDim paddingIdx maxSeqLen vocabSize numEmbeds dtype device) - -data - FoldLayers - (batchSize :: Nat) - (inputSeqLen :: Nat) - (dtype :: D.DType) - (device :: (D.DeviceType, Nat)) = FoldLayers - { -- | optional attention mask - flAttentionMask :: Maybe (Tensor device dtype '[batchSize, inputSeqLen, inputSeqLen]) - } - -instance - ( 1 <= numHeads, - numEmbeds ~ (headDim * numHeads), - All KnownNat '[numEmbeds, numHeads, inputSeqLen, batchSize, headDim], - IsSuffixOf '[numEmbeds] '[batchSize, inputSeqLen, numEmbeds], - KnownDType dtype, - StandardFloatingPointDTypeValidation device dtype, - MatMulDTypeIsValid device dtype, - BasicArithmeticDTypeIsValid device dtype, - GeluDTypeIsValid device dtype, - dtype ~ SumDType dtype, - SumDTypeIsValid device dtype, - KnownDevice device - ) => - Apply' - (FoldLayers batchSize inputSeqLen dtype device) - ( TransformerLayer numEmbeds numHeads ffnDim dtype device, - IO (Tensor device dtype '[batchSize, inputSeqLen, numEmbeds]) - ) - (IO (Tensor device dtype '[batchSize, inputSeqLen, numEmbeds])) - where - apply' FoldLayers {..} (layer, mx) = do - x <- mx - transformerLayer layer flAttentionMask x x x - -transformerLM :: - forall - numAttnLayers - numHeads - ffnDim - paddingIdx - vocabSize - numEmbeds - inputSeqLen - maxSeqLen - batchSize - dtype - device. - ( All KnownNat '[paddingIdx, numEmbeds, inputSeqLen, batchSize], - IsSuffixOf '[numEmbeds] '[batchSize, inputSeqLen, numEmbeds], - paddingIdx + 1 <= vocabSize, - 1 <= inputSeqLen, - HFoldrM - IO - (FoldLayers batchSize inputSeqLen dtype device) - (Tensor device dtype '[batchSize, inputSeqLen, numEmbeds]) - (HReplicateR numAttnLayers (TransformerLayer numEmbeds numHeads ffnDim dtype device)) - (Tensor device dtype '[batchSize, inputSeqLen, numEmbeds]), - BasicArithmeticDTypeIsValid device dtype, - ComparisonDTypeIsValid device dtype, - ComparisonDTypeIsValid device 'D.Int64, - KnownDType dtype, - KnownDevice device - ) => - GPT2 numAttnLayers numHeads ffnDim paddingIdx maxSeqLen vocabSize numEmbeds dtype device -> - Tensor device 'D.Int64 '[batchSize, inputSeqLen] -> - IO (Tensor device dtype '[batchSize, inputSeqLen, vocabSize]) -transformerLM GPT2 {..} xTokens = do - let x = embed tEmbedding xTokens - positions = - expand @'[batchSize, inputSeqLen, numEmbeds] True - -- . (\pos_emb -> trace (show . T.select 0 0 $ toDynamic pos_emb) pos_emb) - . embed tPosEmbedding - . Torch.Typed.Tensor.toDType @D.Int64 - . linspace @inputSeqLen (0 :: Int) - $ natValI @(inputSeqLen - 1) - let x' = x `add` positions - let attentionMask = - unsqueeze @0 - . Torch.Typed.Tensor.toDType @D.Bool - . triu 1 - $ ones @'[inputSeqLen, inputSeqLen] @D.Int8 @device - attentionMask' = - pure . maskedFill attentionMask (-1 / 0 :: Double) $ - zeros @'[batchSize, inputSeqLen, inputSeqLen] @dtype @device - -- _ <- print $ shape x - -- _ <- print (T.select 0 0 . T.squeezeAll $ toDynamic x) - y <- hfoldrM (FoldLayers attentionMask') x' tLayers - return - -- (\final -> trace (show . T.sliceDim 0 0 5 1 . T.select 0 0 . T.squeezeAll $ toDynamic final) final) $ - -- (\fin -> trace (show . T.select 0 0 . T.squeezeAll $ toDynamic fin) forward tProj fin) $ - . forward tProj - $ forward tFinalLN y - -instance - ( All KnownNat '[paddingIdx, numEmbeds, inputSeqLen, batchSize, inputSeqLen], - IsSuffixOf '[numEmbeds] '[batchSize, inputSeqLen, numEmbeds], - paddingIdx + 1 <= vocabSize, - 1 <= inputSeqLen, - HFoldrM - IO - (FoldLayers batchSize inputSeqLen dtype device) - (Tensor device dtype '[batchSize, inputSeqLen, numEmbeds]) - (HReplicateR numAttnLayers (TransformerLayer numEmbeds numHeads ffnDim dtype device)) - (Tensor device dtype '[batchSize, inputSeqLen, numEmbeds]), - BasicArithmeticDTypeIsValid device dtype, - ComparisonDTypeIsValid device dtype, - ComparisonDTypeIsValid device 'D.Int64, - KnownDType dtype, - KnownDevice device - ) => - HasForward (GPT2 numAttnLayers numHeads ffnDim paddingIdx inputSeqLen vocabSize numEmbeds dtype device) (Tensor device 'D.Int64 '[batchSize, inputSeqLen]) (Tensor device dtype '[batchSize, inputSeqLen, vocabSize]) - where - forward model input = unsafePerformIO $ transformerLM model input - forwardStoch model input = transformerLM model input - -sinusoidal :: - forall vocabSize numEmbeds device. - ( All KnownNat '[vocabSize, numEmbeds], - 1 <= vocabSize, - 1 <= Div numEmbeds 2, - (Div numEmbeds 2 * 2) ~ numEmbeds, - StandardFloatingPointDTypeValidation device 'D.Float, - BasicArithmeticDTypeIsValid device 'D.Float, - KnownDevice device - ) => - Tensor device 'D.Float '[vocabSize, numEmbeds] -sinusoidal = - let positions = - unsqueeze @1 - . linspace @vocabSize (0 :: Int) - $ natValI @(vocabSize - 1) - scalingFactors = - exp - . mulScalar (-log (10000 :: Double) / (fromInteger . natVal $ Proxy @(Div numEmbeds 2))) - . linspace @(Div numEmbeds 2) (0 :: Int) - $ natValI @((Div numEmbeds 2) - 1) - radians = mul positions scalingFactors - weights = stack @2 (sin radians :. cos radians :. HNil) - in reshape weights - -instance - ( paddingIdx <= vocabSize, - 1 <= maxSeqLen, - 1 <= vocabSize - paddingIdx, - 1 <= Div numEmbeds 2, - (((vocabSize - paddingIdx) - 1) + (1 + paddingIdx)) ~ vocabSize, - (Div numEmbeds 2 * 2) ~ numEmbeds, - All KnownNat '[ffnDim, paddingIdx, vocabSize, maxSeqLen, numEmbeds], - HReplicate numAttnLayers (TransformerLayerSpec numEmbeds numHeads ffnDim dtype device), - A.Randomizable - (HList (HReplicateR numAttnLayers (TransformerLayerSpec numEmbeds numHeads ffnDim dtype device))) - (HList (HReplicateR numAttnLayers (TransformerLayer numEmbeds numHeads ffnDim dtype device))), - KnownDType dtype, - RandDTypeIsValid device dtype, - StandardFloatingPointDTypeValidation device 'D.Float, - BasicArithmeticDTypeIsValid device 'D.Float, - KnownDevice device - ) => - A.Randomizable - (GPT2Spec numAttnLayers numHeads ffnDim paddingIdx maxSeqLen vocabSize numEmbeds dtype device) - (GPT2 numAttnLayers numHeads ffnDim paddingIdx maxSeqLen vocabSize numEmbeds dtype device) - where - sample GPT2Spec {..} = - GPT2 - <$> A.sample (LearnedEmbeddingWithRandomInitSpec @('Just paddingIdx)) - <*> A.sample (ConstEmbeddingSpec @'Nothing (Torch.Typed.Tensor.toDType sinusoidal)) - <*> A.sample (hreplicate @numAttnLayers lmLayerSpec) - <*> A.sample (LayerNormSpec epsSpec'') - <*> A.sample LinearSpec diff --git a/src/infer.hs b/src/infer.hs index 456c56d..c8cf909 100644 --- a/src/infer.hs +++ b/src/infer.hs @@ -24,24 +24,21 @@ import Control.Monad.Trans.Class (lift) import Control.Monad.Trans.Maybe import Data.ByteString.Char8 (pack) import Data.Constraint -import Data.Dependent.Map -import Data.Functor.Identity import Data.Proxy import GHC.Int (Int64) import GHC.TypeLits import GPT2.CachedModel -import GPT2.HListExtensions import GPT2.Loader import SafeTensors hiding (shape) import System.Environment (getArgs) import Tiktoken (r50k_base, toRanks) import qualified Torch as UT import qualified Torch.DType as D -import qualified Torch.Device as D import Torch.Internal.Cast (cast2) import qualified Torch.Internal.Managed.Native as ATen.Managed import Torch.Typed hiding (length, sample, transformerLM) import Unsafe.Coerce (unsafeCoerce) +import Data.Dependent.Map ((!)) type family MultinomialCheck (n :: Nat) (shape :: [Nat]) (dim :: Nat) (sat :: Maybe Nat) (result :: Maybe a) :: a where MultinomialCheck _ shape dim _ Nothing = DimOutOfBound shape dim @@ -116,6 +113,7 @@ runInference (fp :: FilePath) = do do dict <- hoistMaybe $ mkNumTokensProof @numTokens proxy let (result, cache) = infer @numTokens dict model [Prelude.map fromIntegral tokens] + a = cache ! Blocks lift $ sample result -- | Example: `cabal run -- /Users/jane.doe/model.safetensors`, must be absolute! From 8a5dbd212247e451a24e3edac559d51cd1b9269e Mon Sep 17 00:00:00 2001 From: Andrew Chen Date: Tue, 15 Apr 2025 00:50:39 -0400 Subject: [PATCH 3/3] nix now builds, removed HList dep that I'm not using --- hech-interp.cabal | 1 - stack.yaml | 1 - 2 files changed, 2 deletions(-) diff --git a/hech-interp.cabal b/hech-interp.cabal index e1654f5..aca9f4a 100644 --- a/hech-interp.cabal +++ b/hech-interp.cabal @@ -46,7 +46,6 @@ common base , constraints , tiktoken , megaparsec < 9.7 - , HList == 0.5.4.0 common binary-base diff --git a/stack.yaml b/stack.yaml index 6585a18..e00f18c 100644 --- a/stack.yaml +++ b/stack.yaml @@ -20,4 +20,3 @@ extra-deps: - psqueues-0.2.8.0 - hasktorch-0.2.1.3@sha256:ccb99b703f2c8b2ea1c5f4ad2482bb2a451804b3ddee0499b8f610ea401f08c3,10794 - tiktoken-1.0.3@sha256:5d9b67982d1c24d0556f94ab1f542948ef9d73cc14bd301d5536390832df7858,2401 - - hackage: HList-0.5.4.0