diff --git a/.gitignore b/.gitignore index 2fef60161..36e5ccafd 100644 --- a/.gitignore +++ b/.gitignore @@ -2,3 +2,4 @@ /build/ /install/ mlton-*.tgz +.vscode/ diff --git a/basis-library/schedulers/spork/ForkJoin.sml b/basis-library/schedulers/spork/ForkJoin.sml index 16623aa21..93d28c0bb 100644 --- a/basis-library/schedulers/spork/ForkJoin.sml +++ b/basis-library/schedulers/spork/ForkJoin.sml @@ -3,6 +3,7 @@ struct datatype TokenPolicy = datatype Scheduler.TokenPolicy val spork = Scheduler.SporkJoin.spork + val primSporkChoose = Scheduler.primSporkChoose fun par (f: unit -> 'a, g: unit -> 'b): 'a * 'b = spork { @@ -54,26 +55,14 @@ struct end -signature LOOP_INDEX = -sig - type idx - type t = idx - - val fromInt: int -> idx - val toInt: idx -> int - - val increment: idx -> idx - val midpoint: idx * idx -> idx - val equal: idx * idx -> bool -end - - functor ManagedLoops (LoopIndex: LOOP_INDEX) :> sig val pareduce: (int * int) -> 'a -> (int * 'a -> 'a) -> ('a * 'a -> 'a) -> 'a val pareduceBreakExn: (int * int) -> 'a -> (('a -> exn) * int * 'a -> 'a) -> ('a * 'a -> 'a) -> 'a val reducem: ('a * 'a -> 'a) -> 'a -> (int * int) -> (int -> 'a) -> 'a val parform: (int * int) -> (int -> unit) -> unit + val seqLoop: (int * int) -> (int -> unit) -> unit + val seqReduce: ('a * 'a -> 'a) -> 'a -> (int * int) -> (int -> 'a) -> 'a end = struct type idx = LoopIndex.t @@ -160,52 +149,26 @@ struct fun __inline_always__ parform (lo: int, hi: int) (f: int -> unit) : unit = reducem (fn _ => ()) () (lo, hi) f -end -functor LoopIndexFromWord(WordImpl: WORD) :> LOOP_INDEX = -struct - type idx = WordImpl.word - type t = idx + fun __inline_always__ seqLoop (lo: int, hi: int) (f: int -> unit) : unit = + let + fun loop (i: idx, j: idx) : unit = + if LoopIndex.equal (i, j) then () + else (__inline_always__ f (LoopIndex.toInt i); loop (LoopIndex.increment i, j)) + in + loop (LoopIndex.fromInt (Int.min (lo, hi)), LoopIndex.fromInt hi) + end - fun __inline_always__ toInt (w: idx) = 
__inline_always__ WordImpl.toIntX w - fun __inline_always__ fromInt i = __inline_always__ WordImpl.fromInt i - fun __inline_always__ midpoint (i: idx, j: idx) = + fun __inline_always__ seqReduce (combine: 'a * 'a -> 'a) (zero: 'a) (lo: int, hi: int) (f: int -> 'a) : 'a = let - (* This way is broken! *) - (* val mid = WordImpl.~>> (WordImpl.+ (i, j), 0w1) *) - - val range_size = WordImpl.+ (j, WordImpl.~ i) - val mid = WordImpl.+ (i, WordImpl.div (range_size, WordImpl.fromInt 2)) + fun loop (acc: 'a) (i: idx, j: idx) : 'a = + if LoopIndex.equal (i, j) then acc + else loop (__inline_always__ combine (acc, __inline_always__ f (LoopIndex.toInt i))) (LoopIndex.increment i, j) in - (* If using a different midpoint calculation, consider uncommenting - * the following for debugging/testing. - *) - - (* if toInt i <= toInt mid andalso toInt mid <= toInt j then - () - else - ( print - ( "ERROR: schedulers/spork/ForkJoin.sml: bug! midpoint failure: " - ^ Int.toString (toInt i) - ^ " " - ^ Int.toString (toInt mid) - ^ " " - ^ Int.toString (toInt j) - ^ "\n" - ) - - ; OS.Process.exit OS.Process.failure - ); *) - - mid + loop zero (LoopIndex.fromInt (Int.min (lo, hi)), LoopIndex.fromInt hi) end - - fun __inline_always__ increment (i: idx) = - WordImpl.+ (i, fromInt 1) - - fun __inline_always__ equal (i: idx, j: idx) = (i = j) end @@ -221,11 +184,17 @@ sig val pareduceBreakExn: (int * int) -> 'a -> (('a -> exn) * int * 'a -> 'a) -> ('a * 'a -> 'a) -> 'a val reducem: ('a * 'a -> 'a) -> 'a -> (int * int) -> (int -> 'a) -> 'a + val reduce: ('a * 'a -> 'a) -> 'a -> (int * int) -> (int -> 'a) -> 'a + val reducemDefault: ('a * 'a -> 'a) -> 'a -> (int * int) -> (int -> 'a) -> 'a val parform: (int * int) -> (int -> unit) -> unit + val parformDefault: (int * int) -> (int -> unit) -> unit val parfor: int -> (int * int) -> (int -> unit) -> unit val alloc: int -> 'a array + val seqLoop: (int * int) -> (int -> unit) -> unit + val seqReduce: ('a * 'a -> 'a) -> 'a -> (int * int) -> (int -> 'a) 
-> 'a + val idleTimeSoFar: unit -> Time.time val workTimeSoFar: unit -> Time.time val maxForkDepthSoFar: unit -> int @@ -254,6 +223,11 @@ struct val equal = op= end) + structure Unrolled8 = UnrolledLoops(Word8) + structure Unrolled16 = UnrolledLoops(Word16) + structure Unrolled32 = UnrolledLoops(Word32) + structure Unrolled64 = UnrolledLoops(Word64) + structure Pareduce = Int_ChooseFromInt (struct type 'a t = (int * int) -> 'a -> (int * 'a -> 'a) -> ('a * 'a -> 'a) -> 'a @@ -294,8 +268,144 @@ struct val fIntInf = LoopsInt.parform end) - val pareduce = Pareduce.f - val pareduceBreakExn = PareduceBreakExn.f - val reducem = Reducem.f - val parform = Parform.f + structure UnrolledPareduce = + Int_ChooseFromInt (struct + type 'a t = (int * int) -> 'a -> (int * 'a -> 'a) -> ('a * 'a -> 'a) -> 'a + val fInt8 = Unrolled8.pareduce + val fInt16 = Unrolled16.pareduce + val fInt32 = Unrolled32.pareduce + val fInt64 = Unrolled64.pareduce + val fIntInf = Unrolled64.pareduce + end) + + structure UnrolledPareduceBreakExn = + Int_ChooseFromInt (struct + type 'a t = (int * int) -> 'a -> (('a -> exn) * int * 'a -> 'a) -> ('a * 'a -> 'a) -> 'a + val fInt8 = Unrolled8.pareduceBreakExn + val fInt16 = Unrolled16.pareduceBreakExn + val fInt32 = Unrolled32.pareduceBreakExn + val fInt64 = Unrolled64.pareduceBreakExn + val fIntInf = Unrolled64.pareduceBreakExn + end) + + structure UnrolledReducem = + Int_ChooseFromInt (struct + type 'a t = ('a * 'a -> 'a) -> 'a -> (int * int) -> (int -> 'a) -> 'a + val fInt8 = Unrolled8.reducem + val fInt16 = Unrolled16.reducem + val fInt32 = Unrolled32.reducem + val fInt64 = Unrolled64.reducem + val fIntInf = Unrolled64.reducem + end) + + structure UnrolledParform = + Int_ChooseFromInt (struct + type 'a t = (int * int) -> (int -> unit) -> unit + val fInt8 = Unrolled8.parform + val fInt16 = Unrolled16.parform + val fInt32 = Unrolled32.parform + val fInt64 = Unrolled64.parform + val fIntInf = Unrolled64.parform + end) + + structure SeqLoop = + 
Int_ChooseFromInt (struct + type 'a t = (int * int) -> (int -> unit) -> unit + val fInt8 = Loops8.seqLoop + val fInt16 = Loops16.seqLoop + val fInt32 = Loops32.seqLoop + val fInt64 = Loops64.seqLoop + val fIntInf = LoopsInt.seqLoop + end) + + structure SeqReduce = + Int_ChooseFromInt (struct + type 'a t = ('a * 'a -> 'a) -> 'a -> (int * int) -> (int -> 'a) -> 'a + val fInt8 = Loops8.seqReduce + val fInt16 = Loops16.seqReduce + val fInt32 = Loops32.seqReduce + val fInt64 = Loops64.seqReduce + val fIntInf = LoopsInt.seqReduce + end) + + structure UnrolledSeqLoop = + Int_ChooseFromInt (struct + type 'a t = (int * int) -> (int -> unit) -> unit + val fInt8 = Unrolled8.seqLoop + val fInt16 = Unrolled16.seqLoop + val fInt32 = Unrolled32.seqLoop + val fInt64 = Unrolled64.seqLoop + val fIntInf = Unrolled64.seqLoop + end) + + structure UnrolledSeqReduce = + Int_ChooseFromInt (struct + type 'a t = ('a * 'a -> 'a) -> 'a -> (int * int) -> (int -> 'a) -> 'a + val fInt8 = Unrolled8.seqReduce + val fInt16 = Unrolled16.seqReduce + val fInt32 = Unrolled32.seqReduce + val fInt64 = Unrolled64.seqReduce + val fIntInf = Unrolled64.seqReduce + end) + + local + + fun __inline_always__ unifiedReducem (combine: 'a * 'a -> 'a) (zero: 'a) (lo: int, hi: int) (f: int -> 'a) : 'a = + let + fun __inline_always__ regularImpl () = __inline_always__ Reducem.f combine zero (lo, hi) f + fun __inline_always__ unrolledImpl () = __inline_always__ UnrolledReducem.f combine zero (lo, hi) f + in + primSporkChoose (__inline_always__ f, __inline_always__ unrolledImpl, __inline_always__ regularImpl) + end + + fun unifiedParform (lo: int, hi: int) (f: int -> unit) : unit = + let + fun __inline_always__ regularImpl () = __inline_always__ Parform.f (lo, hi) f + + fun __inline_always__ unrolledImpl () = __inline_always__ UnrolledParform.f (lo, hi) f + in + primSporkChoose (__inline_always__ f, __inline_always__ unrolledImpl, __inline_always__ regularImpl) + end + + fun __inline_always__ unifiedPareduce (lo: int, 
hi: int) (zero: 'a) (step: int * 'a -> 'a) (combine: 'a * 'a -> 'a) : 'a = + let + fun __inline_always__ regularImpl () = + __inline_always__ Pareduce.f (lo, hi) zero step combine + + fun __inline_always__ unrolledImpl () = + __inline_always__ UnrolledPareduce.f (lo, hi) zero step combine + + fun __inline_always__ loopBody i = __inline_always__ step (i, zero) + in + primSporkChoose (__inline_always__ loopBody, __inline_always__ unrolledImpl, __inline_always__ regularImpl) + end + + fun __inline_always__ unifiedSeqLoop (lo: int, hi: int) (f: int -> unit) : unit = + let + fun __inline_always__ regularImpl () = __inline_always__ SeqLoop.f (lo, hi) f + fun __inline_always__ unrolledImpl () = __inline_always__ UnrolledSeqLoop.f (lo, hi) f + in + Scheduler.primLoopChoose (__inline_always__ f, __inline_always__ unrolledImpl, __inline_always__ regularImpl) + end + + fun __inline_always__ unifiedSeqReduce (combine: 'a * 'a -> 'a) (zero: 'a) (lo: int, hi: int) (f: int -> 'a) : 'a = + let + fun __inline_always__ regularImpl () = __inline_always__ SeqReduce.f combine zero (lo, hi) f + fun __inline_always__ unrolledImpl () = __inline_always__ UnrolledSeqReduce.f combine zero (lo, hi) f + in + Scheduler.primLoopChoose (__inline_always__ f, __inline_always__ unrolledImpl, __inline_always__ regularImpl) + end + in + val reducem = __inline_always__ unifiedReducem + val reduce = __inline_always__ unifiedReducem + val reducemDefault = __inline_always__ Reducem.f + val parform = __inline_always__ unifiedParform + val parformDefault = __inline_always__ Parform.f + val pareduce = __inline_always__ unifiedPareduce + val parfor = __inline_always__ ForkJoin0.parfor + val seqLoop = __inline_always__ unifiedSeqLoop + val seqReduce = __inline_always__ unifiedSeqReduce + end + + val pareduceBreakExn = __inline_always__ PareduceBreakExn.f end \ No newline at end of file diff --git a/basis-library/schedulers/spork/LoopIndex.sml b/basis-library/schedulers/spork/LoopIndex.sml new file mode 100644 
index 000000000..1d09ea641 --- /dev/null +++ b/basis-library/schedulers/spork/LoopIndex.sml @@ -0,0 +1,58 @@ +signature LOOP_INDEX = +sig + type idx + type t = idx + + val fromInt: int -> idx + val toInt: idx -> int + + val increment: idx -> idx + val midpoint: idx * idx -> idx + val equal: idx * idx -> bool +end + + +functor LoopIndexFromWord(WordImpl: WORD) :> LOOP_INDEX = +struct + type idx = WordImpl.word + type t = idx + + fun __inline_always__ toInt (w: idx) = __inline_always__ WordImpl.toIntX w + fun __inline_always__ fromInt i = __inline_always__ WordImpl.fromInt i + + fun __inline_always__ midpoint (i: idx, j: idx) = + let + (* This way is broken! *) + (* val mid = WordImpl.~>> (WordImpl.+ (i, j), 0w1) *) + + val range_size = WordImpl.+ (j, WordImpl.~ i) + val mid = WordImpl.+ (i, WordImpl.div (range_size, WordImpl.fromInt 2)) + in + (* If using a different midpoint calculation, consider uncommenting + * the following for debugging/testing. + *) + + (* if toInt i <= toInt mid andalso toInt mid <= toInt j then + () + else + ( print + ( "ERROR: schedulers/spork/ForkJoin.sml: bug! 
midpoint failure: " + ^ Int.toString (toInt i) + ^ " " + ^ Int.toString (toInt mid) + ^ " " + ^ Int.toString (toInt j) + ^ "\n" + ) + + ; OS.Process.exit OS.Process.failure + ); *) + + mid + end + + fun __inline_always__ increment (i: idx) = + WordImpl.+ (i, fromInt 1) + + fun __inline_always__ equal (i: idx, j: idx) = (i = j) +end diff --git a/basis-library/schedulers/spork/Scheduler.sml b/basis-library/schedulers/spork/Scheduler.sml index 4d09fd2d8..099cf030c 100644 --- a/basis-library/schedulers/spork/Scheduler.sml +++ b/basis-library/schedulers/spork/Scheduler.sml @@ -103,13 +103,30 @@ struct * (exn -> 'c) (* exn seq *) * (exn * 'd -> 'c) (* exn sync *) -> 'c; + val primSporkChoose' = + _prim "spork_choose" + : ('u -> 'a) (* loop body *) + * (unit -> 'a) (* unrolled implementation *) + * (unit -> 'a) (* regular implementation *) + -> 'a; + val primLoopChoose' = + _prim "loop_choose" + : ('u -> 'a) (* loop body *) + * (unit -> 'a) (* unrolled implementation *) + * (unit -> 'a) (* regular implementation *) + -> 'a; + fun __inline_always__ primSporkFair (body, spwn, seq, sync, exnseq, exnsync) = __inline_always__ primSporkFair' (body, (), spwn, (), seq, sync, exnseq, exnsync) fun __inline_always__ primSporkKeep (body, spwn, seq, sync, exnseq, exnsync) = __inline_always__ primSporkKeep' (body, (), spwn, (), seq, sync, exnseq, exnsync) fun __inline_always__ primSporkGive (body, spwn, seq, sync, exnseq, exnsync) = __inline_always__ primSporkGive' (body, (), spwn, (), seq, sync, exnseq, exnsync) - + fun __inline_always__ primSporkChoose (loopBody, unrolled, regular) = + __inline_always__ primSporkChoose' (loopBody, unrolled, regular) + fun __inline_always__ primLoopChoose (loopBody, unrolled, regular) = + __inline_always__ primLoopChoose' (loopBody, unrolled, regular) + val primForkThreadAndSetData = _prim "spork_forkThreadAndSetData": Thread.t * 'a -> Thread.p; val primForkThreadAndSetData_youngest = _prim "spork_forkThreadAndSetData_youngest": Thread.t * 'a -> 
Thread.p; @@ -1024,6 +1041,7 @@ struct fun __inline_always__ tryPromoteNow yo = ( Thread.atomicBegin () ; if + (* ! Second heartbeat check *) Heartbeat.enoughToSpawn () andalso #maybeSpawn (sched_package ()) yo (Thread.current ()) then @@ -1052,6 +1070,7 @@ struct val (inject, project) = Universal.embed () fun __inline_always__ body' (): 'a = + (* ! First Heartbeat Check *) ((if not (Heartbeat.enoughToSpawn ()) then () else tryPromoteNow {youngestOptimization = true}); __inline_always__ body ()) diff --git a/basis-library/schedulers/spork/UnrolledLoops.sml b/basis-library/schedulers/spork/UnrolledLoops.sml new file mode 100644 index 000000000..460449cd4 --- /dev/null +++ b/basis-library/schedulers/spork/UnrolledLoops.sml @@ -0,0 +1,314 @@ +functor UnrolledLoops(WordImpl: WORD) :> +sig + val pareduce: (int * int) -> 'a -> (int * 'a -> 'a) -> ('a * 'a -> 'a) -> 'a + val pareduceBreakExn: (int * int) -> 'a -> (('a -> exn) * int * 'a -> 'a) -> ('a * 'a -> 'a) -> 'a + val reducem: ('a * 'a -> 'a) -> 'a -> (int * int) -> (int -> 'a) -> 'a + val parform: (int * int) -> (int -> unit) -> unit + val seqLoop: (int * int) -> (int -> unit) -> unit + val seqReduce: ('a * 'a -> 'a) -> 'a -> (int * int) -> (int -> 'a) -> 'a +end = +struct + + type word = WordImpl.word + fun __inline_always__ w2i w = __inline_always__ WordImpl.toIntX w + fun __inline_always__ i2w i = __inline_always__ WordImpl.fromInt i + + val one = i2w 1 + val two = i2w 2 + val three = i2w 3 + val four = i2w 4 + val five = i2w 5 + val six = i2w 6 + val seven = i2w 7 + val eight = i2w 8 + + fun __inline_always__ midpoint (i: word, j: word) = + WordImpl.+ (i, WordImpl.>> (WordImpl.- (j, i), 0w1)) + + + fun __inline_always__ pareduce (lo, hi) (z: 'a) (step': int * 'a -> 'a) (g: 'a * 'a -> 'a) : 'a = + let + + fun __inline_always__ step (a, i) = + __inline_always__ step' (w2i i, a) + (* __inline_always__ g (a, __inline_always__ f (w2i i)) *) + + + (* fun sequential_loop (a, i: word, j: word) = + if i < j then +
sequential_loop (step (a, i), i + 0w1, j) + else a *) + + + (* fun __inline_always__ next stride = + Word64.min (Word64.<< (stride, 0w1), 0w16) *) + + + fun loop8 (a, i, j) = + if WordImpl.<= (WordImpl.+ (i, eight), j) then + let + fun __inline_never__ spwn a' = + if WordImpl.>= (WordImpl.+ (i, eight), j) then a' else + let + val mid = midpoint (WordImpl.+ (i, eight), j) + in + Scheduler.SporkJoin.spork { + tokenPolicy = Scheduler.TokenPolicyFair, + body = fn () => loop1 (a', WordImpl.+ (i, eight), mid), + spwn = fn () => loop1 (z, mid, j), + seq = fn a'' => loop1 (a'', mid, j), + sync = g, + unstolen = NONE + } + end + in + Scheduler.SporkJoin.spork { + tokenPolicy = Scheduler.TokenPolicyGive, + body = fn () => + let + val a = step (a, i) + val a = step (a, WordImpl.+ (i, one)) + val a = step (a, WordImpl.+ (i, two)) + val a = step (a, WordImpl.+ (i, three)) + val a = step (a, WordImpl.+ (i, four)) + val a = step (a, WordImpl.+ (i, five)) + val a = step (a, WordImpl.+ (i, six)) + val a = step (a, WordImpl.+ (i, seven)) + in + a + end, + seq = fn a' => loop8 (a', WordImpl.+ (i, eight), j), + sync = g, + spwn = fn () => spwn z, + unstolen = SOME spwn + } + end + else + loop1 (a, i, j) + + + + and loop4 (a, i, j) = + if WordImpl.<= (WordImpl.+ (i, four), j) then + let + fun __inline_never__ spwn a' = + if WordImpl.>= (WordImpl.+ (i, four), j) then a' else + let + val mid = midpoint (WordImpl.+ (i, four), j) + in + Scheduler.SporkJoin.spork { + tokenPolicy = Scheduler.TokenPolicyFair, + body = fn () => loop1 (a', WordImpl.+ (i, four), mid), + spwn = fn () => loop1 (z, mid, j), + seq = fn a'' => loop1 (a'', mid, j), + sync = g, + unstolen = NONE + } + end + in + Scheduler.SporkJoin.spork { + tokenPolicy = Scheduler.TokenPolicyGive, + body = fn () => + let + val a = step (a, i) + val a = step (a, WordImpl.+ (i, one)) + val a = step (a, WordImpl.+ (i, two)) + val a = step (a, WordImpl.+ (i, three)) + in + a + end, + seq = fn a' => loop8 (a', WordImpl.+ (i, four), j), + 
sync = g, + spwn = fn () => spwn z, + unstolen = SOME spwn + } + end + else + loop1 (a, i, j) + + + + and loop2 (a, i, j) = + if WordImpl.<= (WordImpl.+ (i, two), j) then + let + fun __inline_never__ spwn a' = + if WordImpl.>= (WordImpl.+ (i, two), j) then a' else + let + val mid = midpoint (WordImpl.+ (i, two), j) + in + Scheduler.SporkJoin.spork { + tokenPolicy = Scheduler.TokenPolicyFair, + body = fn () => loop1 (a', WordImpl.+ (i, two), mid), + spwn = fn () => loop1 (z, mid, j), + seq = fn a'' => loop1 (a'', mid, j), + sync = g, + unstolen = NONE + } + end + in + Scheduler.SporkJoin.spork { + tokenPolicy = Scheduler.TokenPolicyGive, + body = fn () => + let + val a = step (a, i) + val a = step (a, WordImpl.+ (i, one)) + in + a + end, + seq = fn a' => loop4 (a', WordImpl.+ (i, two), j), + sync = g, + spwn = fn () => spwn z, + unstolen = SOME spwn + } + end + else + loop1 (a, i, j) + + + + and loop1 (a, i, j) = + if WordImpl.<= (WordImpl.+ (i, one), j) then + let + fun __inline_never__ spwn a' = + if WordImpl.>= (WordImpl.+ (i, one), j) then a' else + let + val mid = midpoint (WordImpl.+ (i, one), j) + in + Scheduler.SporkJoin.spork { + tokenPolicy = Scheduler.TokenPolicyFair, + body = fn () => loop1 (a', WordImpl.+ (i, one), mid), + spwn = fn () => loop1 (z, mid, j), + seq = fn a'' => loop1 (a'', mid, j), + sync = g, + unstolen = NONE + } + end + in + Scheduler.SporkJoin.spork { + tokenPolicy = Scheduler.TokenPolicyGive, + body = fn () => step (a, i), + seq = fn a' => loop2 (a', WordImpl.+ (i, one), j), + sync = g, + spwn = fn () => spwn z, + unstolen = SOME spwn + } + end + else + a + + in + __inline_always__ + loop1 (z, i2w lo, i2w hi) + end + + + fun __inline_always__ pareduceBreakExn (i: int, j: int) (z: 'a) (step: ('a -> exn) * int * 'a -> 'a) (merge: 'a * 'a -> 'a): 'a = + let exception Break of 'a + fun step' (i, a) = (__inline_always__ step (Break, i, a), true) handle (Break b) => (b, false) + fun merge' ((b1, cont1), (b2, cont2)) = + if cont1 then (merge 
(b1, b2), cont2) else (b1, false) + + (* we can reuse the pareduce structure but need to adapt it for break semantics *) + (* for simplicity, we'll use a basic implementation that wraps pareduce *) + (* a more optimized version would inline the break logic into the unrolled loops *) + + fun continue (f : 'a -> 'a * bool) : 'a * bool -> 'a * bool = + fn (b, cont) => if cont then f b else (b, cont) + + fun iter (b: 'a) (i: word, j: word): 'a * bool = + if i = j then (b, true) else + let + fun __inline_never__ spwn b' = + if WordImpl.>= (WordImpl.+ (i, one), j) then (b', true) else + let val mid = midpoint (WordImpl.+ (i, one), j) in + Scheduler.SporkJoin.spork { + tokenPolicy = Scheduler.TokenPolicyFair, + body = fn () => iter b' (WordImpl.+ (i, one), mid), + spwn = fn () => iter z (mid, j), + seq = continue (fn b' => iter b' (mid, j)), + sync = merge', + unstolen = NONE + } + end + in + Scheduler.SporkJoin.spork { + tokenPolicy = Scheduler.TokenPolicyGive, + body = fn () => __inline_always__ step' (w2i i, b), + spwn = fn () => spwn z, + seq = continue (fn b' => iter b' (WordImpl.+ (i, one), j)), + sync = merge', + unstolen = SOME (continue spwn) + } + end + val (result, cont) = __inline_always__ iter z (i2w (Int.min (i, j)), i2w j) + in + result + end + + + fun __inline_always__ reducem g z (lo, hi) f = + pareduce (lo, hi) z (fn (i, a) => __inline_always__ g (a, __inline_always__ f i)) g + + + fun __inline_always__ parform (lo: int, hi: int) (f: int -> unit) : unit = + pareduce (lo, hi) () (fn (i, _) => f i) (fn _ => ()) + + + fun __inline_always__ seqLoop (lo: int, hi: int) (f: int -> unit) : unit = + let + fun loop8 (i: word, j: word) : unit = + if WordImpl.<= (WordImpl.+ (i, eight), j) then + let + val _ = __inline_always__ f (w2i i) + val _ = __inline_always__ f (w2i (WordImpl.+ (i, one))) + val _ = __inline_always__ f (w2i (WordImpl.+ (i, two))) + val _ = __inline_always__ f (w2i (WordImpl.+ (i, three))) + val _ = __inline_always__ f (w2i (WordImpl.+ (i, 
four))) + val _ = __inline_always__ f (w2i (WordImpl.+ (i, five))) + val _ = __inline_always__ f (w2i (WordImpl.+ (i, six))) + val _ = __inline_always__ f (w2i (WordImpl.+ (i, seven))) + in + loop8 (WordImpl.+ (i, eight), j) + end + else + loop1 (i, j) + + and loop1 (i: word, j: word) : unit = + if WordImpl.< (i, j) then + (__inline_always__ f (w2i i); loop1 (WordImpl.+ (i, one), j)) + else + () + in + loop8 (i2w lo, i2w hi) + end + + + fun __inline_always__ seqReduce (combine: 'a * 'a -> 'a) (zero: 'a) (lo: int, hi: int) (f: int -> 'a) : 'a = + let + fun loop8 (acc: 'a, i: word, j: word) : 'a = + if WordImpl.<= (WordImpl.+ (i, eight), j) then + let + val acc = __inline_always__ combine (acc, __inline_always__ f (w2i i)) + val acc = __inline_always__ combine (acc, __inline_always__ f (w2i (WordImpl.+ (i, one)))) + val acc = __inline_always__ combine (acc, __inline_always__ f (w2i (WordImpl.+ (i, two)))) + val acc = __inline_always__ combine (acc, __inline_always__ f (w2i (WordImpl.+ (i, three)))) + val acc = __inline_always__ combine (acc, __inline_always__ f (w2i (WordImpl.+ (i, four)))) + val acc = __inline_always__ combine (acc, __inline_always__ f (w2i (WordImpl.+ (i, five)))) + val acc = __inline_always__ combine (acc, __inline_always__ f (w2i (WordImpl.+ (i, six)))) + val acc = __inline_always__ combine (acc, __inline_always__ f (w2i (WordImpl.+ (i, seven)))) + in + loop8 (acc, WordImpl.+ (i, eight), j) + end + else + loop1 (acc, i, j) + + and loop1 (acc: 'a, i: word, j: word) : 'a = + if WordImpl.< (i, j) then + loop1 (__inline_always__ combine (acc, __inline_always__ f (w2i i)), WordImpl.+ (i, one), j) + else + acc + in + loop8 (zero, i2w lo, i2w hi) + end + +end \ No newline at end of file diff --git a/basis-library/schedulers/spork/sources.mlb b/basis-library/schedulers/spork/sources.mlb index 62d3e4520..6c380c54c 100644 --- a/basis-library/schedulers/spork/sources.mlb +++ b/basis-library/schedulers/spork/sources.mlb @@ -38,6 +38,8 @@ local in 
Heartbeat.sml Scheduler.sml + LoopIndex.sml + UnrolledLoops.sml end ForkJoin.sml in diff --git a/mlton/atoms/prim.fun b/mlton/atoms/prim.fun index 8e5288965..446389c44 100644 --- a/mlton/atoms/prim.fun +++ b/mlton/atoms/prim.fun @@ -113,6 +113,9 @@ datatype 'a t = | MLton_share (* to rssa (as nop or runtime C fn) *) | MLton_size (* to rssa (as runtime C fn) *) | MLton_touch (* to rssa (as nop) or backend (as nop) *) + | Loop_choose (* closure convert *) + (* Choose between unrolled and regular at compile time *) + | Spork_choose (* closure convert *) | Spork of {tokenSplitPolicy: Word32.word} (* closure convert *) | Spork_forkThreadAndSetData of {youngest: bool} (* to rssa (as runtime C fn) *) | Spork_getData of Spid.t (* backend *) @@ -294,6 +297,8 @@ fun toString (n: 'a t): string = | MLton_share => "MLton_share" | MLton_size => "MLton_size" | MLton_touch => "MLton_touch" + | Loop_choose => "loop_choose" + | Spork_choose => "spork_choose" | Spork {tokenSplitPolicy=0w0} => "spork_fair" | Spork {tokenSplitPolicy=0w1} => "spork_keep" | Spork {tokenSplitPolicy=0w2} => "spork_give" @@ -462,6 +467,9 @@ val equals: 'a t * 'a t -> bool = | (MLton_share, MLton_share) => true | (MLton_size, MLton_size) => true | (MLton_touch, MLton_touch) => true + | (Loop_choose, Loop_choose) => true + (* TODO: Check usage properly *) + | (Spork_choose, Spork_choose) => true | (Spork {tokenSplitPolicy = tsp1}, Spork {tokenSplitPolicy = tsp2}) => tsp1 = tsp2 | (Spork_forkThreadAndSetData yo1, Spork_forkThreadAndSetData yo2) => yo1 = yo2 | (Spork_getData spid, Spork_getData spid') => Spid.equals (spid, spid') @@ -648,6 +656,9 @@ val map: 'a t * ('a -> 'b) -> 'b t = | MLton_size => MLton_size | MLton_touch => MLton_touch | Spork tsp => Spork tsp + (* TODO: Check usage properly *) + | Loop_choose => Loop_choose + | Spork_choose => Spork_choose | Spork_forkThreadAndSetData z => Spork_forkThreadAndSetData z | Spork_getData spid => Spork_getData spid | Real_Math_acos z => Real_Math_acos z @@ 
-864,6 +875,9 @@ val kind: 'a t -> Kind.t = | MLton_size => DependsOnState | MLton_touch => SideEffect | Spork _ => SideEffect + (* TODO: Check usage properly *) + | Loop_choose => SideEffect + | Spork_choose => SideEffect | Spork_forkThreadAndSetData _ => SideEffect | Spork_getData _ => DependsOnState | Real_Math_acos _ => DependsOnState (* depends on rounding mode *) @@ -1077,6 +1091,9 @@ in Spork {tokenSplitPolicy = 0w0}, Spork {tokenSplitPolicy = 0w1}, Spork {tokenSplitPolicy = 0w2}, + (* TODO: Check usage properly *) + Loop_choose, + Spork_choose, Spork_forkThreadAndSetData {youngest=true}, Spork_forkThreadAndSetData {youngest=false}, (*Spork_getData,*) @@ -1279,6 +1296,9 @@ fun 'a checkApp (prim: 'a t, fun oneTarg f = 1 = Vector.length targs andalso done (f (targ 0)) + fun twoTargs f = + 2 = Vector.length targs + andalso done (f (targ 0, targ 1)) fun sixTargs f = 6 = Vector.length targs andalso done (f (targ 0, targ 1, targ 2, targ 3, targ 4, targ 5)) @@ -1441,6 +1461,26 @@ fun 'a checkApp (prim: 'a t, in (eightArgs (cont, taa, spwn, tba, seq, sync, exnseq, exnsync), tc) end) + | Loop_choose => + (* TODO: Check usage properly *) + twoTargs (fn (ta, tu) => + let + val loopBody = arrow (tu, ta) (* First arg: loop body function 'u -> 'a *) + val impl = arrow (unit, ta) (* Second and third args: thunks unit -> 'a *) + in + (threeArgs (loopBody, impl, impl), ta) + end) + | Spork_choose => + (* TODO: Check usage properly *) + (* spork_choose: ('u -> 'v) -> (unit -> 'a) -> (unit -> 'a) -> 'a + * where 'v = 'a, so really: ('u -> 'a) -> (unit -> 'a) -> (unit -> 'a) -> 'a *) + twoTargs (fn (ta, tu) => + let + val loopBody = arrow (tu, ta) (* First arg: loop body function 'u -> 'a *) + val impl = arrow (unit, ta) (* Second and third args: thunks unit -> 'a *) + in + (threeArgs (loopBody, impl, impl), ta) + end) | Spork_forkThreadAndSetData _ => oneTarg (fn t => (twoArgs (thread, t), thread)) | Spork_getData _ => oneTarg (fn t => (noArgs, t)) | Real_Math_acos s => 
realUnary s @@ -1596,6 +1636,23 @@ fun ('a, 'b) extractTargs (prim: 'b t, in six (taa, tar, tba, tbr, td, tc) end + | Loop_choose => + (* TODO: Check usage properly *) + let + val ta = result (* Result type 'a *) + val (tu, _) = deArrow (arg 0) (* First arg: loop body ('u -> 'v) *) + in + Vector.new2 (ta, tu) + end + | Spork_choose => + (* TODO: Check usage properly *) + (* spork_choose: ('u -> 'v) -> 'a -> 'a -> 'a *) + let + val ta = result (* Result type 'a *) + val (tu, _) = deArrow (arg 0) (* First arg: loop body ('u -> 'v) *) + in + Vector.new2 (ta, tu) + end | Spork_forkThreadAndSetData _ => one (arg 1) | Spork_getData _ => one result | Ref_assign _ => one (deRef (arg 0)) diff --git a/mlton/atoms/prim.sig b/mlton/atoms/prim.sig index 99cb3fcd6..5c4b47117 100644 --- a/mlton/atoms/prim.sig +++ b/mlton/atoms/prim.sig @@ -104,6 +104,8 @@ signature PRIM = | MLton_share (* to rssa (as nop or runtime C fn) *) | MLton_size (* to rssa (as runtime C fn) *) | MLton_touch (* to rssa (as nop) or backend (as nop) *) + | Loop_choose (* closure convert *) + | Spork_choose (* TODO: closure convert / SSA / SSA2 / RSSA ? *) | Spork of {tokenSplitPolicy: Word32.word} (* closure convert *) | Spork_forkThreadAndSetData of {youngest: bool} (* to rssa (as runtime C fn) *) | Spork_getData of Spid.t (* backend *) diff --git a/mlton/backend/allocate-variables.fun b/mlton/backend/allocate-variables.fun index d351977f2..e3cfb8592 100644 --- a/mlton/backend/allocate-variables.fun +++ b/mlton/backend/allocate-variables.fun @@ -293,7 +293,7 @@ structure Info = (* ------------------------------------------------- *) (* allocate *) (* ------------------------------------------------- *) - +(* TODO: Mess with isLoop stuff? 
*) fun allocate {function = f: Rssa.Function.t, paramOffsets: (Rssa.Var.t * Rssa.Type.t) vector -> {offset: Bytes.t, ty: Rssa.Type.t, volatile: bool} vector, sporkNesting = {maxSporkNestLength: int, diff --git a/mlton/backend/allocate-variables.sig b/mlton/backend/allocate-variables.sig index 787c068b4..960e7296d 100644 --- a/mlton/backend/allocate-variables.sig +++ b/mlton/backend/allocate-variables.sig @@ -17,7 +17,7 @@ signature ALLOCATE_VARIABLES_STRUCTS = signature ALLOCATE_VARIABLES = sig include ALLOCATE_VARIABLES_STRUCTS - + (* TODO: mess with isLoop stuff? *) val allocate: {function: Rssa.Function.t, paramOffsets: (Rssa.Var.t * Rssa.Type.t) vector -> {offset: Bytes.t, ty: Rssa.Type.t, volatile: bool} vector, diff --git a/mlton/backend/backend.fun b/mlton/backend/backend.fun index 66347f667..74f3ba52e 100644 --- a/mlton/backend/backend.fun +++ b/mlton/backend/backend.fun @@ -176,6 +176,7 @@ fun toMachine (rssa: Rssa.Program.t) = NONE => (NONE, fn _ => NONE) | SOME {sourceMaps, getFrameSourceSeqIndex} => (SOME sourceMaps, getFrameSourceSeqIndex) (* Frame info *) + (* TODO: isLoopy stuff? *) local val frameInfos: M.FrameInfo.t list ref = ref [] val nextFrameInfo = Counter.generator 0 @@ -1181,6 +1182,7 @@ fun toMachine (rssa: Rssa.Program.t) = (parallelMove {dsts = dsts', srcs = srcs'}, M.Transfer.Goto dst) end + (* TODO: loopy? 
*) | R.Transfer.Spork {spid, cont, spwn} => simple (M.Transfer.Spork {spid = spid, diff --git a/mlton/backend/bounce-vars.fun b/mlton/backend/bounce-vars.fun index 6c4d8af32..b3aa734e2 100644 --- a/mlton/backend/bounce-vars.fun +++ b/mlton/backend/bounce-vars.fun @@ -435,6 +435,7 @@ fun transform p = end in case kind of + (* TODO: isLoop probably not here *) Kind.SporkSpwn _ => split () | Kind.SpoinSync _ => split () | _ => Block.T {args = args, diff --git a/mlton/backend/machine.fun b/mlton/backend/machine.fun index 606069c3c..ffd189f8c 100644 --- a/mlton/backend/machine.fun +++ b/mlton/backend/machine.fun @@ -613,6 +613,7 @@ structure Transfer = return: {return: Label.t, handler: Label.t option, size: Bytes.t} option} + (* TODO: isLoop? *) | Spork of {spid: Spid.t, data: StackOffset.t, cont: Label.t, diff --git a/mlton/backend/machine.sig b/mlton/backend/machine.sig index f01abc6e4..73fef988f 100644 --- a/mlton/backend/machine.sig +++ b/mlton/backend/machine.sig @@ -185,6 +185,7 @@ signature MACHINE = handler: Label.t option (* must be kind Handler*), size: Bytes.t} option} | Goto of Label.t (* must be kind Jump *) + (* TODO: isLoopy stuff? 
*) | Spork of {spid: Spid.t, data: StackOffset.t, cont: Label.t, diff --git a/mlton/closure-convert/closure-convert.fun b/mlton/closure-convert/closure-convert.fun index 8b4b8a35b..d743759c7 100644 --- a/mlton/closure-convert/closure-convert.fun +++ b/mlton/closure-convert/closure-convert.fun @@ -224,6 +224,97 @@ val traceLoopBind = ("exp", SprimExp.layout exp)], Unit.layout) +(* similar to lambdaSize in polyvariance *) +fun sxmlLambdaSize (l: Slambda.t): int = + let + fun loopExp (e: Sexp.t, n: int): int = + List.fold + (Sexp.decs e, n, fn (d, n) => + case d of + Sdec.MonoVal {exp, ...} => loopPrimExp (exp, n + 1) + | Sdec.PolyVal {exp, ...} => loopExp (exp, n + 1) + | Sdec.Fun {decs, ...} => Vector.fold (decs, n, fn ({lambda, ...}, n) => + loopLambda (lambda, n)) + | Sdec.Exception _ => n + 1) + and loopLambda (l: Slambda.t, n): int = + let val m = loopExp (Slambda.body l, 0) + in m + n + end + and loopPrimExp (e: SprimExp.t, n: int): int = + case e of + SprimExp.Case {cases, default, ...} => + let + val n = n + 1 + in + Scases.fold + (cases, + (case default of + NONE => n + | SOME e => loopExp (e, n)), + loopExp) + end + | SprimExp.Handle {try, handler, ...} => + loopExp (try, loopExp (handler, n + 1)) + | SprimExp.Lambda l => loopLambda (l, n + 1) + | SprimExp.Profile _ => n + | _ => n + 1 + in + loopExp (Slambda.body l, 0) + end + +fun analyzeLambdaBody (l: Slambda.t): unit = + let + val counts = ref [] + fun addCount (name: string) = List.push (counts, name) + + fun loopExp (e: Sexp.t): unit = + List.foreach + (Sexp.decs e, fn d => + case d of + Sdec.MonoVal {exp, var, ...} => + (addCount (concat ["MonoVal(", Var.toString var, ")"]) + ; loopPrimExp exp) + | Sdec.PolyVal {exp, ...} => + (addCount "PolyVal" + ; loopExp exp) + | Sdec.Fun {decs, ...} => + (addCount (concat ["Fun(", Int.toString (Vector.length decs), " lambdas)"]) + ; Vector.foreach (decs, fn {lambda, ...} => loopLambda lambda)) + | Sdec.Exception _ => addCount "Exception") + and loopLambda (l: 
Slambda.t): unit = + (addCount "Lambda" + ; loopExp (Slambda.body l)) + and loopPrimExp (e: SprimExp.t): unit = + case e of + SprimExp.App {func, ...} => + addCount (concat ["App(", Layout.toString (SvarExp.layout func), ")"]) + | SprimExp.Case _ => (addCount "Case"; (* would need to recurse into cases *) ()) + | SprimExp.ConApp {con, ...} => + addCount (concat ["ConApp(", Con.toString con, ")"]) + | SprimExp.Const c => addCount (concat ["Const(", Const.toString c, ")"]) + | SprimExp.Handle _ => (addCount "Handle"; (* would need to recurse *) ()) + | SprimExp.Lambda l => loopLambda l + | SprimExp.PrimApp {prim, ...} => + addCount (concat ["PrimApp(", Prim.toString prim, ")"]) + | SprimExp.Profile _ => () + | SprimExp.Raise _ => addCount "Raise" + | SprimExp.Select {offset, ...} => + addCount (concat ["Select(#", Int.toString offset, ")"]) + | SprimExp.Tuple xs => + addCount (concat ["Tuple(", Int.toString (Vector.length xs), ")"]) + | SprimExp.Var x => addCount (concat ["Var(", Layout.toString (SvarExp.layout x), ")"]) + val _ = loopExp (Slambda.body l) + val _ = Control.diagnostics + (fn display => + List.foreach + (rev (!counts), fn name => + display (let open Layout + in seq [str " - ", str name] + end))) + in + () + end + fun closureConvert (program as Sxml.Program.T {datatypes, body}): Ssa.Program.t = let @@ -476,6 +567,41 @@ fun closureConvert in () end + | PrimApp {prim = Prim.Spork_choose, args, ...} => + (* spork_choose: ('u -> 'a) -> (unit -> 'a) -> (unit -> 'a) -> 'a + * closure conversion later applies either unrolled or regular to unit, + * so model an application of both: each lambda reaching either thunk + * is passed unit and its body flows into the result *) + let + fun arg i = Vector.sub (args, i) + val regular = arg 2 + val unrolled = arg 1 + val unitTy = Stype.tuple (Vector.new0 ()) + val unitArg = Value.fromType unitTy + val result = new () + (* unrolled thunk: same flow as the regular thunk below *) + val _ = Value.addHandler + (varExp unrolled, fn l => + let + val {arg = formal, body, ...} = Lambda.dest (Value.Lambda.dest l) + in + Value.coerce {from = unitArg, to = value formal} + ; Value.coerce {from = expValue body, to = result} + end) + val _ = Value.addHandler + (* regular thunk *) + (varExp regular, fn l => + let + val lambda = Value.Lambda.dest l + val {arg = formal, body, ...} = Lambda.dest lambda + in + Value.coerce {from = unitArg, to = value formal} + ; Value.coerce {from = expValue body, to = 
result} + (* ! body is not really needed for result here *) + end) + in + () + end | PrimApp {prim, args, ...} => set (Value.primApply {prim = prim, args = varExps args, @@ -812,7 +927,7 @@ fun closureConvert cases = Dexp.Con (Vector.map - (cons, fn {lambda, con} => + (cons, fn {lambda, con} => let val {arg = param, body, ...} = Slambda.dest lambda val info as LambdaInfo.T {name, ...} = lambdaInfo lambda @@ -1255,6 +1370,194 @@ fun closureConvert in (exp, ac) end + | SprimExp.PrimApp {prim = Prim.Spork_choose, targs, args} => + (* spork_choose: ('u -> 'a) -> (unit -> 'a) -> (unit -> 'a) -> 'a; args are (loop body, unrolled thunk, regular thunk). The loop body is only measured here; exactly one of the two thunks is then applied to unit. *) + let + fun arg i = Vector.sub (args, i) + val loopBodyVar = arg 0 + val unrolled = arg 1 + val regular = arg 2 + + val loopBodyInfo = varExpInfo loopBodyVar + val loopBodySize = + let + val v = #value loopBodyInfo + fun extractLambda (v: Value.t): Slambda.t option = + case Value.dest v of + Value.Lambdas ls => + (case Lambdas.toList ls of + l :: _ => SOME (Value.Lambda.dest l) (* NOTE(review): only the first lambda of the set is measured *) + | [] => NONE) + | _ => NONE + in + case extractLambda v of + SOME l => + let + val size = sxmlLambdaSize l + val _ = Control.diagnostics + (fn display => + display (let open Layout + in seq [str "spork_choose loop body size: ", + Int.layout size] + end)) + val _ = analyzeLambdaBody l + in + size + end + | NONE => + let + val _ = Control.diagnostics + (fn display => + display (let open Layout + in seq [str "spork_choose: no lambda found for ", + SvarExp.layout loopBodyVar] + end)) + in + 0 + end + end + + (* choose between unrolled and regular based on loop body size: sizes up to and including the threshold take unrolled, larger ones take regular; a body for which no lambda was found is sized 0 *) + val threshold = !Control.sporkChooseThreshold + val chosenImpl = if loopBodySize <= threshold then unrolled else regular + val _ = Control.diagnostics + (fn display => + display (let open Layout + in seq [str " Decision: using ", + str (if loopBodySize <= threshold then "UNROLLED" else "REGULAR"), + str " (threshold=", Int.layout threshold, str ")"] + end)) + + (* apply the chosen implementation to unit *) + val func = varExpInfo 
chosenImpl + val funcVal = VarInfo.value func + (* unit as a type, a Dexp, and a Value *) + val unitTy = Type.tuple (Vector.new0 ()) (* TODO: refactor as SType.unit *) + val unitExp = Dexp.tuple {exps = Vector.new0 (), ty = unitTy} (* NOTE: Dexp.unit is defined as Tuple {exps = Vector.new0 (), ty = Type.unit}; consider using it here *) + val unitVal = Value.fromType (Stype.tuple (Vector.new0 ())) + val {cons, ...} = valueLambdasInfo funcVal + in + (* FIXME: heavy code duplication from apply: case on the closure's constructor, then call that lambda's function with its environment and the coerced unit argument *) + (Dexp.casee + {test = convertVarInfo func, + ty = ty, + default = NONE, + cases = + Dexp.Con + (Vector.map + (cons, fn {lambda, con} => + let + val {arg = param, body, ...} = Slambda.dest lambda + val info as LambdaInfo.T {name, ...} = lambdaInfo lambda + val result = expValue body + val env = (Var.newString "env", lambdaInfoType info) + in {con = con, + args = Vector.new1 env, + body = coerce (Dexp.call + {func = name, + args = Vector.new2 (Dexp.var env, + coerce (unitExp, unitVal, + value param)), + inline = InlineAttr.Auto, + ty = valueType result}, + result, v)} + end))}, + ac) + end + | SprimExp.PrimApp {prim = Prim.Loop_choose, targs, args} => + (* loop_choose: ('u -> 'a) -> (unit -> 'a) -> (unit -> 'a) -> 'a; handled like spork_choose above but with a separately defined threshold. NOTE(review): no Loop_choose case is visible in the flow-analysis pass; confirm that the generic PrimApp case models the thunk application. *) + let + fun arg i = Vector.sub (args, i) + val loopBodyVar = arg 0 + val unrolled = arg 1 + val regular = arg 2 + + val loopBodyInfo = varExpInfo loopBodyVar + val loopBodySize = + let + val v = #value loopBodyInfo + fun extractLambda (v: Value.t): Slambda.t option = + case Value.dest v of + Value.Lambdas ls => + (case Lambdas.toList ls of + l :: _ => SOME (Value.Lambda.dest l) + | [] => NONE) + | _ => NONE + in + case extractLambda v of + SOME l => + let + val size = sxmlLambdaSize l + val _ = Control.diagnostics + (fn display => + display (let open Layout + in seq [str "loop_choose loop body size: ", + Int.layout size] + end)) + val _ = analyzeLambdaBody l + in + size + end + | NONE => + let + val _ = Control.diagnostics + (fn display => + display (let open Layout + in seq [str "loop_choose: no lambda found for ", + 
SvarExp.layout loopBodyVar] + end)) + in + 0 + end + end + + (* choose between unrolled and regular based on loop body size: sizes up to and including the threshold take unrolled, larger ones take regular *) + val threshold = 100 (* hardcoded for now; not tied to Control.sporkChooseThreshold *) + val chosenImpl = if loopBodySize <= threshold then unrolled else regular + val _ = Control.diagnostics + (fn display => + display (let open Layout + in seq [str " Decision: using ", + str (if loopBodySize <= threshold then "UNROLLED" else "REGULAR"), + str " (threshold=", Int.layout threshold, str ")"] + end)) + + (* apply the chosen implementation to unit *) + val func = varExpInfo chosenImpl + val funcVal = VarInfo.value func + (* unit as a type, a Dexp, and a Value *) + val unitTy = Type.tuple (Vector.new0 ()) + val unitExp = Dexp.tuple {exps = Vector.new0 (), ty = unitTy} + val unitVal = Value.fromType (Stype.tuple (Vector.new0 ())) + val {cons, ...} = valueLambdasInfo funcVal + in + (* FIXME: heavy code duplication from apply and from the spork_choose case: case on the closure's constructor, then call that lambda's function with its environment and the coerced unit argument *) + (Dexp.casee + {test = convertVarInfo func, + ty = ty, + default = NONE, + cases = + Dexp.Con + (Vector.map + (cons, fn {lambda, con} => + let + val {arg = param, body, ...} = Slambda.dest lambda + val info as LambdaInfo.T {name, ...} = lambdaInfo lambda + val result = expValue body + val env = (Var.newString "env", lambdaInfoType info) + in {con = con, + args = Vector.new1 env, + body = coerce (Dexp.call + {func = name, + args = Vector.new2 (Dexp.var env, + coerce (unitExp, unitVal, + value param)), + inline = InlineAttr.Auto, + ty = valueType result}, + result, v)} + end))}, + ac) + end | SprimExp.PrimApp {prim, targs, args} => let val prim = Prim.map (prim, convertType) @@ -1340,6 +1643,7 @@ fun closureConvert v1 (coerce (convertVarInfo y, VarInfo.value y, v))) end + (* TODO: loopy stuff? 
*) | Prim.Spork_forkThreadAndSetData _ => let val t = varExpInfo (arg 0) @@ -1490,4 +1794,4 @@ fun closureConvert program end -end +end \ No newline at end of file diff --git a/mlton/codegen/c-codegen/c-codegen.fun b/mlton/codegen/c-codegen/c-codegen.fun index e02ad16a7..2db20e83b 100644 --- a/mlton/codegen/c-codegen/c-codegen.fun +++ b/mlton/codegen/c-codegen/c-codegen.fun @@ -145,6 +145,7 @@ structure Operand = | _ => false end +(* # NOTE: Safe to inline in C *) fun implementsPrim (p: 'a Prim.t): bool = let datatype z = datatype Prim.t @@ -1512,6 +1513,7 @@ fun output {program as Machine.Program.T {chunks, frameInfos, main, ...}, )) ; jump label) | Goto dst => gotoLabel (dst, {tab = true}) + (* TODO: does loopiness affect this section? *) | Spork {cont, ...} => gotoLabel (cont, {tab = true}) | Raise {raisesTo} => (outputStatement (Statement.PrimApp diff --git a/mlton/control/control-flags.sig b/mlton/control/control-flags.sig index 8b0404160..5d23804b2 100644 --- a/mlton/control/control-flags.sig +++ b/mlton/control/control-flags.sig @@ -69,6 +69,7 @@ signature CONTROL_FLAGS = val closureConvertGlobalize: bool ref val closureConvertShrink: bool ref + val sporkChooseThreshold: int ref structure Codegen: sig diff --git a/mlton/control/control-flags.sml b/mlton/control/control-flags.sml index d4f002de4..b30403b10 100644 --- a/mlton/control/control-flags.sml +++ b/mlton/control/control-flags.sml @@ -230,6 +230,11 @@ val closureConvertShrink = control {name = "closureConvertShrink", default = true, toString = Bool.toString} +val sporkChooseThreshold: int ref = (* largest loop-body size, as measured during closure conversion, for which spork_choose selects the unrolled variant *) + control {name = "spork choose threshold", + default = 100, + toString = Int.toString} + structure Codegen = struct datatype t = diff --git a/mlton/main/main.fun b/mlton/main/main.fun index 9e43abe2d..6a5777187 100644 --- a/mlton/main/main.fun +++ b/mlton/main/main.fun @@ -290,6 +290,12 @@ fun makeOptions {usage} = (Expert, "closure-convert-shrink", " {true|false}", "whether to shrink during closure 
conversion", Bool (fn b => (closureConvertShrink := b))), + (* argument placeholder " <n>" is what usage output shows after the flag name; negative values are rejected *) + (Expert, "spork-choose-threshold", " <n>", + "threshold for spork_choose loop body size decision, default = 100", + Int (fn n => + if n < 0 + then usage "spork-choose-threshold must be non-negative" + else sporkChooseThreshold := n)), (Normal, "codegen", concat [" {", String.concatWith diff --git a/mlton/ssa/analyze.fun b/mlton/ssa/analyze.fun index 37f53cc9d..d8b8d760c 100644 --- a/mlton/ssa/analyze.fun +++ b/mlton/ssa/analyze.fun @@ -149,6 +149,7 @@ fun 'a analyze in () end | Goto {dst, args} => coerces ("goto", values args, labelValues dst) + (* TODO: loopy stuff? *) | Spork {cont, spwn, ...} => let fun ensureNullary j = diff --git a/mlton/ssa/analyze2.fun b/mlton/ssa/analyze2.fun index dea9e994a..1246e2d7a 100644 --- a/mlton/ssa/analyze2.fun +++ b/mlton/ssa/analyze2.fun @@ -161,6 +161,7 @@ fun 'a analyze in () end | Goto {dst, args} => coerces ("goto", values args, labelValues dst) + (* TODO: loopy stuff? *) | Spork {cont, spwn, ...} => let fun ensureNullary j = diff --git a/mlton/ssa/common-arg.fun b/mlton/ssa/common-arg.fun index 786840019..71d1cec83 100644 --- a/mlton/ssa/common-arg.fun +++ b/mlton/ssa/common-arg.fun @@ -131,6 +131,7 @@ fun transform (Program.T {datatypes, globals, functions, main}) = (Cases.foreach (cases, visitLabelArgs) ; Option.app (default, visitLabelArgs)) | Goto {dst, args} => flowVarsLabelArgs (args, dst) + (* TODO: loopy stuff? 
*) | Spork {spid, cont, spwn} => (visitLabelArgs cont; visitLabelArgs spwn) | Spoin {spid, seq, sync} => diff --git a/mlton/ssa/direct-exp.fun b/mlton/ssa/direct-exp.fun index 92d6daac2..e816f394f 100644 --- a/mlton/ssa/direct-exp.fun +++ b/mlton/ssa/direct-exp.fun @@ -43,6 +43,7 @@ datatype t = | Let of {decs: {var: Var.t, exp: t} list, body: t} | Name of t * (Var.t -> t) + (* TODO: isLoop *) | Spork of {spid: Spid.t, cont: t, spwn: t, diff --git a/mlton/ssa/direct-exp.sig b/mlton/ssa/direct-exp.sig index c5a430a72..38e0d65f9 100644 --- a/mlton/ssa/direct-exp.sig +++ b/mlton/ssa/direct-exp.sig @@ -57,6 +57,7 @@ signature DIRECT_EXP = val linearizeGoto: t * Return.Handler.t * Label.t -> Label.t * Block.t list val name: t * (Var.t -> t) -> t + (* TODO: I don't think this needs to be messed with *) val spork: {spid: Spid.t, cont: t, spwn: t, ty: Type.t} -> t val spoin: {spid: Spid.t, seq: t, sync: t, ty: Type.t} -> t val primApp: {args: t vector, diff --git a/sources.mlb b/sources.mlb new file mode 100644 index 000000000..edab4ba3b --- /dev/null +++ b/sources.mlb @@ -0,0 +1,7 @@ +basis-library/basis.mlb +basis-library/mpl.mlb +basis-library/schedulers/spork/sources.mlb +mlton/sources.mlb +mlton/codegen/c-codegen/sources.mlb +mlton/codegen/c-codegen/c-codegen.sig +mlton/codegen/c-codegen/c-codegen.fun \ No newline at end of file