From 073d0ad9ba69c7738b6f332cd3e3719c4e91bbed Mon Sep 17 00:00:00 2001 From: Sundara Vishnu Satish Date: Tue, 23 Sep 2025 16:52:42 -0400 Subject: [PATCH 01/14] chore: some comments on sporky locations --- .gitignore | 1 + basis-library/schedulers/spork/Scheduler.sml | 2 ++ mlton/atoms/prim.fun | 8 ++++++++ mlton/atoms/prim.sig | 1 + mlton/backend/allocate-variables.fun | 2 +- mlton/backend/allocate-variables.sig | 2 +- mlton/backend/backend.fun | 2 ++ mlton/backend/bounce-vars.fun | 1 + mlton/backend/machine.fun | 1 + mlton/backend/machine.sig | 1 + mlton/closure-convert/closure-convert.fun | 3 +++ mlton/codegen/c-codegen/c-codegen.fun | 2 ++ mlton/ssa/analyze.fun | 1 + mlton/ssa/analyze2.fun | 1 + mlton/ssa/common-arg.fun | 1 + mlton/ssa/direct-exp.fun | 1 + mlton/ssa/direct-exp.sig | 1 + sources.mlb | 7 +++++++ 18 files changed, 36 insertions(+), 2 deletions(-) create mode 100644 sources.mlb diff --git a/.gitignore b/.gitignore index 2fef60161..36e5ccafd 100644 --- a/.gitignore +++ b/.gitignore @@ -2,3 +2,4 @@ /build/ /install/ mlton-*.tgz +.vscode/ diff --git a/basis-library/schedulers/spork/Scheduler.sml b/basis-library/schedulers/spork/Scheduler.sml index 4d09fd2d8..7099b5189 100644 --- a/basis-library/schedulers/spork/Scheduler.sml +++ b/basis-library/schedulers/spork/Scheduler.sml @@ -1024,6 +1024,7 @@ struct fun __inline_always__ tryPromoteNow yo = ( Thread.atomicBegin () ; if + (* ! Second heartbeat check *) Heartbeat.enoughToSpawn () andalso #maybeSpawn (sched_package ()) yo (Thread.current ()) then @@ -1052,6 +1053,7 @@ struct val (inject, project) = Universal.embed () fun __inline_always__ body' (): 'a = + (* ! 
First Heartbeat Check *) ((if not (Heartbeat.enoughToSpawn ()) then () else tryPromoteNow {youngestOptimization = true}); __inline_always__ body ()) diff --git a/mlton/atoms/prim.fun b/mlton/atoms/prim.fun index 8e5288965..7ed2b0a51 100644 --- a/mlton/atoms/prim.fun +++ b/mlton/atoms/prim.fun @@ -113,6 +113,7 @@ datatype 'a t = | MLton_share (* to rssa (as nop or runtime C fn) *) | MLton_size (* to rssa (as runtime C fn) *) | MLton_touch (* to rssa (as nop) or backend (as nop) *) + (* TODO: add isLoop: bool*) | Spork of {tokenSplitPolicy: Word32.word} (* closure convert *) | Spork_forkThreadAndSetData of {youngest: bool} (* to rssa (as runtime C fn) *) | Spork_getData of Spid.t (* backend *) @@ -294,6 +295,7 @@ fun toString (n: 'a t): string = | MLton_share => "MLton_share" | MLton_size => "MLton_size" | MLton_touch => "MLton_touch" + (* TODO: spork {tokenSplitPolicy, isLoop} *) | Spork {tokenSplitPolicy=0w0} => "spork_fair" | Spork {tokenSplitPolicy=0w1} => "spork_keep" | Spork {tokenSplitPolicy=0w2} => "spork_give" @@ -462,6 +464,7 @@ val equals: 'a t * 'a t -> bool = | (MLton_share, MLton_share) => true | (MLton_size, MLton_size) => true | (MLton_touch, MLton_touch) => true + (* TODO: isLoop1 = isLoop2 *) | (Spork {tokenSplitPolicy = tsp1}, Spork {tokenSplitPolicy = tsp2}) => tsp1 = tsp2 | (Spork_forkThreadAndSetData yo1, Spork_forkThreadAndSetData yo2) => yo1 = yo2 | (Spork_getData spid, Spork_getData spid') => Spid.equals (spid, spid') @@ -647,6 +650,7 @@ val map: 'a t * ('a -> 'b) -> 'b t = | MLton_share => MLton_share | MLton_size => MLton_size | MLton_touch => MLton_touch + (* TODO: spork tsp isLoop *) | Spork tsp => Spork tsp | Spork_forkThreadAndSetData z => Spork_forkThreadAndSetData z | Spork_getData spid => Spork_getData spid @@ -863,6 +867,7 @@ val kind: 'a t -> Kind.t = | MLton_share => SideEffect | MLton_size => DependsOnState | MLton_touch => SideEffect + (* TODO *) | Spork _ => SideEffect | Spork_forkThreadAndSetData _ => SideEffect | Spork_getData 
_ => DependsOnState @@ -1074,6 +1079,7 @@ in MLton_share, MLton_size, MLton_touch, + (* TODO: Spork {tokenSplitPolicy = ..., isLoop = ...} *) Spork {tokenSplitPolicy = 0w0}, Spork {tokenSplitPolicy = 0w1}, Spork {tokenSplitPolicy = 0w2}, @@ -1428,6 +1434,7 @@ fun 'a checkApp (prim: 'a t, | MLton_size => oneTarg (fn t => (oneArg t, csize)) | MLton_touch => oneTarg (fn t => (oneArg t, unit)) | Spork _ => + (* TODO: Add isLoop arg -> sevenTargs and nineTargs? *) (* spork: ('aa -> 'ar) * 'aa * ('ba * 'd -> 'br) * 'ba * ('ar -> 'c) * ('ar * 'd -> 'c) * (exn -> 'c) * (exn * 'd -> 'c) -> 'c *) (* ('aa -> 'ar) * 'aa * ('ba * 'd -> 'br) * 'ba * ('ar -> 'c) * ('ar * 'd -> 'c) * (exn -> 'c) * (exn * 'd -> 'c) -> 'c *) sixTargs (fn (taa, tar, tba, tbr, td, tc) => @@ -1585,6 +1592,7 @@ fun ('a, 'b) extractTargs (prim: 'b t, | MLton_size => one (arg 0) | MLton_touch => one (arg 0) | Spork _ => + (* TODO: add isLoop *) (* spork: ('aa -> 'ar) * 'aa * ('ba * 'd -> 'br) * 'ba * ('ar -> 'c) * ('ar * 'd -> 'c) * (exn -> 'c) * (exn * 'd -> 'c) -> 'c *) let val (taa, tar) = deArrow (arg 0) diff --git a/mlton/atoms/prim.sig b/mlton/atoms/prim.sig index 99cb3fcd6..045636155 100644 --- a/mlton/atoms/prim.sig +++ b/mlton/atoms/prim.sig @@ -104,6 +104,7 @@ signature PRIM = | MLton_share (* to rssa (as nop or runtime C fn) *) | MLton_size (* to rssa (as runtime C fn) *) | MLton_touch (* to rssa (as nop) or backend (as nop) *) + (* TODO: add isLoop bool *) | Spork of {tokenSplitPolicy: Word32.word} (* closure convert *) | Spork_forkThreadAndSetData of {youngest: bool} (* to rssa (as runtime C fn) *) | Spork_getData of Spid.t (* backend *) diff --git a/mlton/backend/allocate-variables.fun b/mlton/backend/allocate-variables.fun index d351977f2..e3cfb8592 100644 --- a/mlton/backend/allocate-variables.fun +++ b/mlton/backend/allocate-variables.fun @@ -293,7 +293,7 @@ structure Info = (* ------------------------------------------------- *) (* allocate *) (* 
------------------------------------------------- *) - +(* TODO: Mess with isLoop stuff? *) fun allocate {function = f: Rssa.Function.t, paramOffsets: (Rssa.Var.t * Rssa.Type.t) vector -> {offset: Bytes.t, ty: Rssa.Type.t, volatile: bool} vector, sporkNesting = {maxSporkNestLength: int, diff --git a/mlton/backend/allocate-variables.sig b/mlton/backend/allocate-variables.sig index 787c068b4..960e7296d 100644 --- a/mlton/backend/allocate-variables.sig +++ b/mlton/backend/allocate-variables.sig @@ -17,7 +17,7 @@ signature ALLOCATE_VARIABLES_STRUCTS = signature ALLOCATE_VARIABLES = sig include ALLOCATE_VARIABLES_STRUCTS - + (* TODO: mess with isLoop stuff? *) val allocate: {function: Rssa.Function.t, paramOffsets: (Rssa.Var.t * Rssa.Type.t) vector -> {offset: Bytes.t, ty: Rssa.Type.t, volatile: bool} vector, diff --git a/mlton/backend/backend.fun b/mlton/backend/backend.fun index 66347f667..74f3ba52e 100644 --- a/mlton/backend/backend.fun +++ b/mlton/backend/backend.fun @@ -176,6 +176,7 @@ fun toMachine (rssa: Rssa.Program.t) = NONE => (NONE, fn _ => NONE) | SOME {sourceMaps, getFrameSourceSeqIndex} => (SOME sourceMaps, getFrameSourceSeqIndex) (* Frame info *) + (* TODO: isLoopy stuff? *) local val frameInfos: M.FrameInfo.t list ref = ref [] val nextFrameInfo = Counter.generator 0 @@ -1181,6 +1182,7 @@ fun toMachine (rssa: Rssa.Program.t) = (parallelMove {dsts = dsts', srcs = srcs'}, M.Transfer.Goto dst) end + (* TODO: loopy? 
*) | R.Transfer.Spork {spid, cont, spwn} => simple (M.Transfer.Spork {spid = spid, diff --git a/mlton/backend/bounce-vars.fun b/mlton/backend/bounce-vars.fun index 6c4d8af32..b3aa734e2 100644 --- a/mlton/backend/bounce-vars.fun +++ b/mlton/backend/bounce-vars.fun @@ -435,6 +435,7 @@ fun transform p = end in case kind of + (* TODO: isLoop probably not here *) Kind.SporkSpwn _ => split () | Kind.SpoinSync _ => split () | _ => Block.T {args = args, diff --git a/mlton/backend/machine.fun b/mlton/backend/machine.fun index 606069c3c..ffd189f8c 100644 --- a/mlton/backend/machine.fun +++ b/mlton/backend/machine.fun @@ -613,6 +613,7 @@ structure Transfer = return: {return: Label.t, handler: Label.t option, size: Bytes.t} option} + (* TODO: isLoop? *) | Spork of {spid: Spid.t, data: StackOffset.t, cont: Label.t, diff --git a/mlton/backend/machine.sig b/mlton/backend/machine.sig index f01abc6e4..73fef988f 100644 --- a/mlton/backend/machine.sig +++ b/mlton/backend/machine.sig @@ -185,6 +185,7 @@ signature MACHINE = handler: Label.t option (* must be kind Handler*), size: Bytes.t} option} | Goto of Label.t (* must be kind Jump *) + (* TODO: isLoopy stuff? 
*) | Spork of {spid: Spid.t, data: StackOffset.t, cont: Label.t, diff --git a/mlton/closure-convert/closure-convert.fun b/mlton/closure-convert/closure-convert.fun index 8b4b8a35b..f3afa3546 100644 --- a/mlton/closure-convert/closure-convert.fun +++ b/mlton/closure-convert/closure-convert.fun @@ -158,6 +158,7 @@ val convertPrimExpInfo = Trace.info "ClosureConvert.convertPrimExp" val valueTypeInfo = Trace.info "ClosureConvert.valueType" structure LambdaFree = LambdaFree (Sxml) +(* # NOTE: Free variable analysis *) local open LambdaFree @@ -1101,6 +1102,7 @@ fun closureConvert args = Vector.new1 (lambdaInfoTuple info)}, ac) end + (* TODO: isLoop bool *) | SprimExp.PrimApp {prim = Prim.Spork {tokenSplitPolicy}, targs, args} => (* spork: ('aa -> 'ar) * 'aa * ('ba * 'd -> 'br) * 'bb * ('ar -> 'c) * ('ar * 'd -> 'c) -> 'c *) let @@ -1340,6 +1342,7 @@ fun closureConvert v1 (coerce (convertVarInfo y, VarInfo.value y, v))) end + (* TODO: loopy stuff? *) | Prim.Spork_forkThreadAndSetData _ => let val t = varExpInfo (arg 0) diff --git a/mlton/codegen/c-codegen/c-codegen.fun b/mlton/codegen/c-codegen/c-codegen.fun index e02ad16a7..2db20e83b 100644 --- a/mlton/codegen/c-codegen/c-codegen.fun +++ b/mlton/codegen/c-codegen/c-codegen.fun @@ -145,6 +145,7 @@ structure Operand = | _ => false end +(* # NOTE: Safe to inline in C *) fun implementsPrim (p: 'a Prim.t): bool = let datatype z = datatype Prim.t @@ -1512,6 +1513,7 @@ fun output {program as Machine.Program.T {chunks, frameInfos, main, ...}, )) ; jump label) | Goto dst => gotoLabel (dst, {tab = true}) + (* TODO: does loopiness affect this section? 
*) | Spork {cont, ...} => gotoLabel (cont, {tab = true}) | Raise {raisesTo} => (outputStatement (Statement.PrimApp diff --git a/mlton/ssa/analyze.fun b/mlton/ssa/analyze.fun index 37f53cc9d..d8b8d760c 100644 --- a/mlton/ssa/analyze.fun +++ b/mlton/ssa/analyze.fun @@ -149,6 +149,7 @@ fun 'a analyze in () end | Goto {dst, args} => coerces ("goto", values args, labelValues dst) + (* TODO: loopy stuff? *) | Spork {cont, spwn, ...} => let fun ensureNullary j = diff --git a/mlton/ssa/analyze2.fun b/mlton/ssa/analyze2.fun index dea9e994a..1246e2d7a 100644 --- a/mlton/ssa/analyze2.fun +++ b/mlton/ssa/analyze2.fun @@ -161,6 +161,7 @@ fun 'a analyze in () end | Goto {dst, args} => coerces ("goto", values args, labelValues dst) + (* TODO: loopy stuff? *) | Spork {cont, spwn, ...} => let fun ensureNullary j = diff --git a/mlton/ssa/common-arg.fun b/mlton/ssa/common-arg.fun index 786840019..71d1cec83 100644 --- a/mlton/ssa/common-arg.fun +++ b/mlton/ssa/common-arg.fun @@ -131,6 +131,7 @@ fun transform (Program.T {datatypes, globals, functions, main}) = (Cases.foreach (cases, visitLabelArgs) ; Option.app (default, visitLabelArgs)) | Goto {dst, args} => flowVarsLabelArgs (args, dst) + (* TODO: loopy stuff? 
*) | Spork {spid, cont, spwn} => (visitLabelArgs cont; visitLabelArgs spwn) | Spoin {spid, seq, sync} => diff --git a/mlton/ssa/direct-exp.fun b/mlton/ssa/direct-exp.fun index 92d6daac2..e816f394f 100644 --- a/mlton/ssa/direct-exp.fun +++ b/mlton/ssa/direct-exp.fun @@ -43,6 +43,7 @@ datatype t = | Let of {decs: {var: Var.t, exp: t} list, body: t} | Name of t * (Var.t -> t) + (* TODO: isLoop *) | Spork of {spid: Spid.t, cont: t, spwn: t, diff --git a/mlton/ssa/direct-exp.sig b/mlton/ssa/direct-exp.sig index c5a430a72..38e0d65f9 100644 --- a/mlton/ssa/direct-exp.sig +++ b/mlton/ssa/direct-exp.sig @@ -57,6 +57,7 @@ signature DIRECT_EXP = val linearizeGoto: t * Return.Handler.t * Label.t -> Label.t * Block.t list val name: t * (Var.t -> t) -> t + (* TODO: I don't think this needs to be messed with *) val spork: {spid: Spid.t, cont: t, spwn: t, ty: Type.t} -> t val spoin: {spid: Spid.t, seq: t, sync: t, ty: Type.t} -> t val primApp: {args: t vector, diff --git a/sources.mlb b/sources.mlb new file mode 100644 index 000000000..edab4ba3b --- /dev/null +++ b/sources.mlb @@ -0,0 +1,7 @@ +basis-library/basis.mlb +basis-library/mpl.mlb +basis-library/schedulers/spork/sources.mlb +mlton/sources.mlb +mlton/codegen/c-codegen/sources.mlb +mlton/codegen/c-codegen/c-codegen.sig +mlton/codegen/c-codegen/c-codegen.fun \ No newline at end of file From 6343044bdb71ebb19b99240c25e02745c3c50085 Mon Sep 17 00:00:00 2001 From: Sundara Vishnu Satish Date: Wed, 24 Sep 2025 16:42:06 -0400 Subject: [PATCH 02/14] =?UTF-8?q?chore:=20copy=20in=20unrolled=20loops,=20?= =?UTF-8?q?name=20new=20prim=C2=BF?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- basis-library/schedulers/spork/Scheduler.sml | 7 +- .../schedulers/spork/UnrolledLoops.sml | 187 ++++++++++++++++++ basis-library/schedulers/spork/sources.mlb | 1 + mlton/atoms/prim.fun | 6 +- 4 files changed, 198 insertions(+), 3 deletions(-) create mode 100644 
basis-library/schedulers/spork/UnrolledLoops.sml diff --git a/basis-library/schedulers/spork/Scheduler.sml b/basis-library/schedulers/spork/Scheduler.sml index 7099b5189..7023fd9f9 100644 --- a/basis-library/schedulers/spork/Scheduler.sml +++ b/basis-library/schedulers/spork/Scheduler.sml @@ -103,13 +103,18 @@ struct * (exn -> 'c) (* exn seq *) * (exn * 'd -> 'c) (* exn sync *) -> 'c; + val primSporkChoose' = + _prim "spork_choose" + : ('u -> 'v) -> 'a -> 'a -> 'a + fun __inline_always__ primSporkFair (body, spwn, seq, sync, exnseq, exnsync) = __inline_always__ primSporkFair' (body, (), spwn, (), seq, sync, exnseq, exnsync) fun __inline_always__ primSporkKeep (body, spwn, seq, sync, exnseq, exnsync) = __inline_always__ primSporkKeep' (body, (), spwn, (), seq, sync, exnseq, exnsync) fun __inline_always__ primSporkGive (body, spwn, seq, sync, exnseq, exnsync) = __inline_always__ primSporkGive' (body, (), spwn, (), seq, sync, exnseq, exnsync) - + (* TODO: add primSporkChoose *) + val primForkThreadAndSetData = _prim "spork_forkThreadAndSetData": Thread.t * 'a -> Thread.p; val primForkThreadAndSetData_youngest = _prim "spork_forkThreadAndSetData_youngest": Thread.t * 'a -> Thread.p; diff --git a/basis-library/schedulers/spork/UnrolledLoops.sml b/basis-library/schedulers/spork/UnrolledLoops.sml new file mode 100644 index 000000000..398185699 --- /dev/null +++ b/basis-library/schedulers/spork/UnrolledLoops.sml @@ -0,0 +1,187 @@ +structure Parfor: PARFOR = +struct + + type word = Word64.word + fun __inline_always__ w2i w = __inline_always__ Word64.toIntX w + fun __inline_always__ i2w i = __inline_always__ Word64.fromInt i + + fun __inline_always__ midpoint (i: word, j: word) = + i + (Word64.>> (j - i, 0w1)) + + + fun __inline_always__ pareduce (lo, hi) (z: 'a) (step': int * 'a -> 'a) (g: 'a * 'a -> 'a) : 'a = + let + + fun __inline_always__ step (a, i) = + __inline_always__ step' (w2i i, a) + (* __inline_always__ g (a, __inline_always__ f (w2i i)) *) + + + (* fun 
sequential_loop (a, i: word, j: word) = + if i < j then + sequential_loop (step (a, i), i + 0w1, j) + else a *) + + + (* fun __inline_always__ next stride = + Word64.min (Word64.<< (stride, 0w1), 0w16) *) + + + fun loop8 (a, i, j) = + if i + 0w8 <= j then + let + fun __inline_never__ spwn a' = + if i + 0w8 >= j then a' else + let + val mid = midpoint (i + 0w8, j) + in + ForkJoin.spork { + tokenPolicy = ForkJoin.TokenPolicyFair, + body = fn () => loop1 (a', i + 0w8, mid), + spwn = fn () => loop1 (z, mid, j), + seq = fn a'' => loop1 (a'', mid, j), + sync = g, + unstolen = NONE + } + end + in + ForkJoin.spork { + tokenPolicy = ForkJoin.TokenPolicyGive, + body = fn () => + let + val a = step (a, i) + val a = step (a, i+0w1) + val a = step (a, i+0w2) + val a = step (a, i+0w3) + val a = step (a, i+0w4) + val a = step (a, i+0w5) + val a = step (a, i+0w6) + val a = step (a, i+0w7) + in + a + end, + seq = fn a' => loop8 (a', i + 0w8, j), + sync = g, + spwn = fn () => spwn z, + unstolen = SOME spwn + } + end + else + loop1 (a, i, j) + + + + and loop4 (a, i, j) = + if i + 0w4 <= j then + let + fun __inline_never__ spwn a' = + if i + 0w4 >= j then a' else + let + val mid = midpoint (i + 0w4, j) + in + ForkJoin.spork { + tokenPolicy = ForkJoin.TokenPolicyFair, + body = fn () => loop1 (a', i + 0w4, mid), + spwn = fn () => loop1 (z, mid, j), + seq = fn a'' => loop1 (a'', mid, j), + sync = g, + unstolen = NONE + } + end + in + ForkJoin.spork { + tokenPolicy = ForkJoin.TokenPolicyGive, + body = fn () => + let + val a = step (a, i) + val a = step (a, i+0w1) + val a = step (a, i+0w2) + val a = step (a, i+0w3) + in + a + end, + seq = fn a' => loop8 (a', i + 0w4, j), + sync = g, + spwn = fn () => spwn z, + unstolen = SOME spwn + } + end + else + loop1 (a, i, j) + + + + and loop2 (a, i, j) = + if i + 0w2 <= j then + let + fun __inline_never__ spwn a' = + if i + 0w2 >= j then a' else + let + val mid = midpoint (i + 0w2, j) + in + ForkJoin.spork { + tokenPolicy = ForkJoin.TokenPolicyFair, 
+ body = fn () => loop1 (a', i + 0w2, mid), + spwn = fn () => loop1 (z, mid, j), + seq = fn a'' => loop1 (a'', mid, j), + sync = g, + unstolen = NONE + } + end + in + ForkJoin.spork { + tokenPolicy = ForkJoin.TokenPolicyGive, + body = fn () => + let + val a = step (a, i) + val a = step (a, i+0w1) + in + a + end, + seq = fn a' => loop4 (a', i + 0w2, j), + sync = g, + spwn = fn () => spwn z, + unstolen = SOME spwn + } + end + else + loop1 (a, i, j) + + + + and loop1 (a, i, j) = + if i + 0w1 <= j then + let + fun __inline_never__ spwn a' = + if i + 0w1 >= j then a' else + let + val mid = midpoint (i + 0w1, j) + in + ForkJoin.spork { + tokenPolicy = ForkJoin.TokenPolicyFair, + body = fn () => loop1 (a', i + 0w1, mid), + spwn = fn () => loop1 (z, mid, j), + seq = fn a'' => loop1 (a'', mid, j), + sync = g, + unstolen = NONE + } + end + in + ForkJoin.spork { + tokenPolicy = ForkJoin.TokenPolicyGive, + body = fn () => step (a, i), + seq = fn a' => loop2 (a', i + 0w1, j), + sync = g, + spwn = fn () => spwn z, + unstolen = SOME spwn + } + end + else + a + + in + __inline_always__ + loop1 (z, i2w lo, i2w hi) + end + +end \ No newline at end of file diff --git a/basis-library/schedulers/spork/sources.mlb b/basis-library/schedulers/spork/sources.mlb index 62d3e4520..d168d0cc0 100644 --- a/basis-library/schedulers/spork/sources.mlb +++ b/basis-library/schedulers/spork/sources.mlb @@ -38,6 +38,7 @@ local in Heartbeat.sml Scheduler.sml + UnrolledLoops.sml end ForkJoin.sml in diff --git a/mlton/atoms/prim.fun b/mlton/atoms/prim.fun index 7ed2b0a51..afff7c0b1 100644 --- a/mlton/atoms/prim.fun +++ b/mlton/atoms/prim.fun @@ -113,7 +113,8 @@ datatype 'a t = | MLton_share (* to rssa (as nop or runtime C fn) *) | MLton_size (* to rssa (as runtime C fn) *) | MLton_touch (* to rssa (as nop) or backend (as nop) *) - (* TODO: add isLoop: bool*) + (* TODO: choose between unrolled and regular *) + | Spork_choose | Spork of {tokenSplitPolicy: Word32.word} (* closure convert *) | 
Spork_forkThreadAndSetData of {youngest: bool} (* to rssa (as runtime C fn) *) | Spork_getData of Spid.t (* backend *) @@ -295,7 +296,8 @@ fun toString (n: 'a t): string = | MLton_share => "MLton_share" | MLton_size => "MLton_size" | MLton_touch => "MLton_touch" - (* TODO: spork {tokenSplitPolicy, isLoop} *) + (* TODO: choose spork *) + | Spork_choose => "spork_choose" | Spork {tokenSplitPolicy=0w0} => "spork_fair" | Spork {tokenSplitPolicy=0w1} => "spork_keep" | Spork {tokenSplitPolicy=0w2} => "spork_give" From 2c8c02c056d7043eb46df66f20f3ab05fd918d80 Mon Sep 17 00:00:00 2001 From: Sundara Vishnu Satish Date: Mon, 29 Sep 2025 18:50:12 -0400 Subject: [PATCH 03/14] feat: unified interface for ForkJoin --- basis-library/schedulers/spork/ForkJoin.sml | 78 ++++++++++++++++++- basis-library/schedulers/spork/Scheduler.sml | 5 +- .../schedulers/spork/UnrolledLoops.sml | 34 ++++---- mlton/atoms/prim.fun | 7 +- mlton/atoms/prim.sig | 3 +- 5 files changed, 103 insertions(+), 24 deletions(-) diff --git a/basis-library/schedulers/spork/ForkJoin.sml b/basis-library/schedulers/spork/ForkJoin.sml index 16623aa21..6fe80b1ca 100644 --- a/basis-library/schedulers/spork/ForkJoin.sml +++ b/basis-library/schedulers/spork/ForkJoin.sml @@ -294,8 +294,80 @@ struct val fIntInf = LoopsInt.parform end) - val pareduce = Pareduce.f + local + (* fallback to regular implementation until runtime supports spork_choose *) + fun primSporkChoose (loopBody, unrolled, regular) = regular + + fun unifiedReducem (combine: 'a * 'a -> 'a) (zero: 'a) (lo: int, hi: int) (f: int -> 'a) : 'a = + let + fun regularImpl () = + let + val pareduce = case Int.precision of + SOME 8 => Loops8.pareduce + | SOME 16 => Loops16.pareduce + | SOME 32 => Loops32.pareduce + | SOME 64 => Loops64.pareduce + | _ => LoopsInt.pareduce + in + pareduce (lo, hi) zero (fn (i, a) => combine (a, f i)) combine + end + + fun unrolledImpl () = + Parfor.pareduce (lo, hi) zero (fn (i, a) => combine (a, f i)) combine + in + primSporkChoose (f, 
unrolledImpl (), regularImpl ()) + end + + fun unifiedParform (lo: int, hi: int) (f: int -> unit) : unit = + let + fun regularImpl () = + let + val parform = case Int.precision of + SOME 8 => Loops8.parform + | SOME 16 => Loops16.parform + | SOME 32 => Loops32.parform + | SOME 64 => Loops64.parform + | _ => LoopsInt.parform + in + parform (lo, hi) f + end + + fun unrolledImpl () = + let + val _ = Parfor.pareduce (lo, hi) () (fn (i, _) => f i) (fn _ => ()) + in + () + end + in + primSporkChoose (f, unrolledImpl (), regularImpl ()) + end + + fun unifiedPareduce (lo: int, hi: int) (zero: 'a) (step: int * 'a -> 'a) (combine: 'a * 'a -> 'a) : 'a = + let + fun regularImpl () = + let + val pareduce = case Int.precision of + SOME 8 => Loops8.pareduce + | SOME 16 => Loops16.pareduce + | SOME 32 => Loops32.pareduce + | SOME 64 => Loops64.pareduce + | _ => LoopsInt.pareduce + in + pareduce (lo, hi) zero step combine + end + + fun unrolledImpl () = + Parfor.pareduce (lo, hi) zero step combine + + fun loopBody i = step (i, zero) + in + primSporkChoose (loopBody, unrolledImpl (), regularImpl ()) + end + in + val reducem = unifiedReducem + val parform = unifiedParform + val pareduce = unifiedPareduce + end + val pareduceBreakExn = PareduceBreakExn.f - val reducem = Reducem.f - val parform = Parform.f end \ No newline at end of file diff --git a/basis-library/schedulers/spork/Scheduler.sml b/basis-library/schedulers/spork/Scheduler.sml index 7023fd9f9..9a59f1eb0 100644 --- a/basis-library/schedulers/spork/Scheduler.sml +++ b/basis-library/schedulers/spork/Scheduler.sml @@ -103,9 +103,11 @@ struct * (exn -> 'c) (* exn seq *) * (exn * 'd -> 'c) (* exn sync *) -> 'c; + (* TODO: commented out until runtime implements spork_choose val primSporkChoose' = _prim "spork_choose" : ('u -> 'v) -> 'a -> 'a -> 'a + *) fun __inline_always__ primSporkFair (body, spwn, seq, sync, exnseq, exnsync) = __inline_always__ primSporkFair' (body, (), spwn, (), seq, sync, exnseq, exnsync) @@ -113,7 +115,8 
@@ struct __inline_always__ primSporkKeep' (body, (), spwn, (), seq, sync, exnseq, exnsync) fun __inline_always__ primSporkGive (body, spwn, seq, sync, exnseq, exnsync) = __inline_always__ primSporkGive' (body, (), spwn, (), seq, sync, exnseq, exnsync) - (* TODO: add primSporkChoose *) + (* TODO: Re-enable after implementing spork_choose *) + fun __inline_always__ primSporkChoose (loopBody, unrolled, regular) = regular val primForkThreadAndSetData = _prim "spork_forkThreadAndSetData": Thread.t * 'a -> Thread.p; val primForkThreadAndSetData_youngest = _prim "spork_forkThreadAndSetData_youngest": Thread.t * 'a -> Thread.p; diff --git a/basis-library/schedulers/spork/UnrolledLoops.sml b/basis-library/schedulers/spork/UnrolledLoops.sml index 398185699..6613f28cb 100644 --- a/basis-library/schedulers/spork/UnrolledLoops.sml +++ b/basis-library/schedulers/spork/UnrolledLoops.sml @@ -1,4 +1,4 @@ -structure Parfor: PARFOR = +structure Parfor = struct type word = Word64.word @@ -35,8 +35,8 @@ struct let val mid = midpoint (i + 0w8, j) in - ForkJoin.spork { - tokenPolicy = ForkJoin.TokenPolicyFair, + Scheduler.SporkJoin.spork { + tokenPolicy = Scheduler.TokenPolicyFair, body = fn () => loop1 (a', i + 0w8, mid), spwn = fn () => loop1 (z, mid, j), seq = fn a'' => loop1 (a'', mid, j), @@ -45,8 +45,8 @@ struct } end in - ForkJoin.spork { - tokenPolicy = ForkJoin.TokenPolicyGive, + Scheduler.SporkJoin.spork { + tokenPolicy = Scheduler.TokenPolicyGive, body = fn () => let val a = step (a, i) @@ -79,8 +79,8 @@ struct let val mid = midpoint (i + 0w4, j) in - ForkJoin.spork { - tokenPolicy = ForkJoin.TokenPolicyFair, + Scheduler.SporkJoin.spork { + tokenPolicy = Scheduler.TokenPolicyFair, body = fn () => loop1 (a', i + 0w4, mid), spwn = fn () => loop1 (z, mid, j), seq = fn a'' => loop1 (a'', mid, j), @@ -89,8 +89,8 @@ struct } end in - ForkJoin.spork { - tokenPolicy = ForkJoin.TokenPolicyGive, + Scheduler.SporkJoin.spork { + tokenPolicy = Scheduler.TokenPolicyGive, body = fn () => 
let val a = step (a, i) @@ -119,8 +119,8 @@ struct let val mid = midpoint (i + 0w2, j) in - ForkJoin.spork { - tokenPolicy = ForkJoin.TokenPolicyFair, + Scheduler.SporkJoin.spork { + tokenPolicy = Scheduler.TokenPolicyFair, body = fn () => loop1 (a', i + 0w2, mid), spwn = fn () => loop1 (z, mid, j), seq = fn a'' => loop1 (a'', mid, j), @@ -129,8 +129,8 @@ struct } end in - ForkJoin.spork { - tokenPolicy = ForkJoin.TokenPolicyGive, + Scheduler.SporkJoin.spork { + tokenPolicy = Scheduler.TokenPolicyGive, body = fn () => let val a = step (a, i) @@ -157,8 +157,8 @@ struct let val mid = midpoint (i + 0w1, j) in - ForkJoin.spork { - tokenPolicy = ForkJoin.TokenPolicyFair, + Scheduler.SporkJoin.spork { + tokenPolicy = Scheduler.TokenPolicyFair, body = fn () => loop1 (a', i + 0w1, mid), spwn = fn () => loop1 (z, mid, j), seq = fn a'' => loop1 (a'', mid, j), @@ -167,8 +167,8 @@ struct } end in - ForkJoin.spork { - tokenPolicy = ForkJoin.TokenPolicyGive, + Scheduler.SporkJoin.spork { + tokenPolicy = Scheduler.TokenPolicyGive, body = fn () => step (a, i), seq = fn a' => loop2 (a', i + 0w1, j), sync = g, diff --git a/mlton/atoms/prim.fun b/mlton/atoms/prim.fun index afff7c0b1..df66efef6 100644 --- a/mlton/atoms/prim.fun +++ b/mlton/atoms/prim.fun @@ -114,7 +114,7 @@ datatype 'a t = | MLton_size (* to rssa (as runtime C fn) *) | MLton_touch (* to rssa (as nop) or backend (as nop) *) (* TODO: choose between unrolled and regular *) - | Spork_choose + | Spork_choose of {tokenSplitPolicy: Word32.word} (* closure convert *) | Spork of {tokenSplitPolicy: Word32.word} (* closure convert *) | Spork_forkThreadAndSetData of {youngest: bool} (* to rssa (as runtime C fn) *) | Spork_getData of Spid.t (* backend *) @@ -297,7 +297,10 @@ fun toString (n: 'a t): string = | MLton_size => "MLton_size" | MLton_touch => "MLton_touch" (* TODO: choose spork *) - | Spork_choose => "spork_choose" + | Spork_choose {tokenSplitPolicy=0w0} => "spork_choose_fair" + | Spork_choose {tokenSplitPolicy=0w1} => 
"spork_choose_keep" + | Spork_choose {tokenSplitPolicy=0w2} => "spork_choose_give" + | Spork_choose {tokenSplitPolicy=pol} => Error.bug ("unknown spork tokensplitpolicy " ^ Word32.toString pol) | Spork {tokenSplitPolicy=0w0} => "spork_fair" | Spork {tokenSplitPolicy=0w1} => "spork_keep" | Spork {tokenSplitPolicy=0w2} => "spork_give" diff --git a/mlton/atoms/prim.sig b/mlton/atoms/prim.sig index 045636155..b58b457be 100644 --- a/mlton/atoms/prim.sig +++ b/mlton/atoms/prim.sig @@ -104,7 +104,8 @@ signature PRIM = | MLton_share (* to rssa (as nop or runtime C fn) *) | MLton_size (* to rssa (as runtime C fn) *) | MLton_touch (* to rssa (as nop) or backend (as nop) *) - (* TODO: add isLoop bool *) + (* TODO: Check usage properly *) + | Spork_choose of {tokenSplitPolicy: Word32.word} (* closure convert *) | Spork of {tokenSplitPolicy: Word32.word} (* closure convert *) | Spork_forkThreadAndSetData of {youngest: bool} (* to rssa (as runtime C fn) *) | Spork_getData of Spid.t (* backend *) From d1967a81aeed62c1e34a7ee80ae4ed58060902de Mon Sep 17 00:00:00 2001 From: Sundara Vishnu Satish Date: Mon, 29 Sep 2025 18:52:42 -0400 Subject: [PATCH 04/14] fix: forgor choose in prim.fun --- mlton/atoms/prim.fun | 3 +++ 1 file changed, 3 insertions(+) diff --git a/mlton/atoms/prim.fun b/mlton/atoms/prim.fun index df66efef6..7f49aab13 100644 --- a/mlton/atoms/prim.fun +++ b/mlton/atoms/prim.fun @@ -1088,6 +1088,9 @@ in Spork {tokenSplitPolicy = 0w0}, Spork {tokenSplitPolicy = 0w1}, Spork {tokenSplitPolicy = 0w2}, + Spork_choose {tokenSplitPolicy = 0w0}, + Spork_choose {tokenSplitPolicy = 0w1}, + Spork_choose {tokenSplitPolicy = 0w2}, Spork_forkThreadAndSetData {youngest=true}, Spork_forkThreadAndSetData {youngest=false}, (*Spork_getData,*) From 602127a86405da912cdf99f9c3d94c52d0599a6c Mon Sep 17 00:00:00 2001 From: Sundara Vishnu Satish Date: Tue, 30 Sep 2025 11:51:53 -0400 Subject: [PATCH 05/14] fix: choose unrolled - trying to compare performance --- 
basis-library/schedulers/spork/ForkJoin.sml | 6 +++--- basis-library/schedulers/spork/UnrolledLoops.sml | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/basis-library/schedulers/spork/ForkJoin.sml b/basis-library/schedulers/spork/ForkJoin.sml index 6fe80b1ca..e202595a2 100644 --- a/basis-library/schedulers/spork/ForkJoin.sml +++ b/basis-library/schedulers/spork/ForkJoin.sml @@ -313,7 +313,7 @@ struct end fun unrolledImpl () = - Parfor.pareduce (lo, hi) zero (fn (i, a) => combine (a, f i)) combine + Unrolled.pareduce (lo, hi) zero (fn (i, a) => combine (a, f i)) combine in primSporkChoose (f, unrolledImpl (), regularImpl ()) end @@ -334,7 +334,7 @@ struct fun unrolledImpl () = let - val _ = Parfor.pareduce (lo, hi) () (fn (i, _) => f i) (fn _ => ()) + val _ = Unrolled.pareduce (lo, hi) () (fn (i, _) => f i) (fn _ => ()) in () end @@ -357,7 +357,7 @@ struct end fun unrolledImpl () = - Parfor.pareduce (lo, hi) zero step combine + Unrolled.pareduce (lo, hi) zero step combine fun loopBody i = step (i, zero) in diff --git a/basis-library/schedulers/spork/UnrolledLoops.sml b/basis-library/schedulers/spork/UnrolledLoops.sml index 6613f28cb..3b04c8e20 100644 --- a/basis-library/schedulers/spork/UnrolledLoops.sml +++ b/basis-library/schedulers/spork/UnrolledLoops.sml @@ -1,4 +1,4 @@ -structure Parfor = +structure Unrolled = struct type word = Word64.word From 0db76124b9c3d99413c7e1b01a08efbcf321a288 Mon Sep 17 00:00:00 2001 From: Sundara Vishnu Satish Date: Tue, 30 Sep 2025 13:11:49 -0400 Subject: [PATCH 06/14] =?UTF-8?q?fix:=20=F0=9F=98=A1=20choose=20unrolled?= =?UTF-8?q?=20-=20trying=20to=20compare=20performance?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- basis-library/schedulers/spork/ForkJoin.sml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/basis-library/schedulers/spork/ForkJoin.sml b/basis-library/schedulers/spork/ForkJoin.sml index e202595a2..b7261b10c 100644 --- 
a/basis-library/schedulers/spork/ForkJoin.sml +++ b/basis-library/schedulers/spork/ForkJoin.sml @@ -296,7 +296,7 @@ struct local (* fallback to regular implementation until runtime supports spork_choose *) - fun primSporkChoose (loopBody, unrolled, regular) = regular + fun primSporkChoose (loopBody, unrolled, regular) = unrolled fun unifiedReducem (combine: 'a * 'a -> 'a) (zero: 'a) (lo: int, hi: int) (f: int -> 'a) : 'a = let From df27f49038b42544fcc841e76b598279e503c149 Mon Sep 17 00:00:00 2001 From: Sundara Vishnu Satish Date: Tue, 30 Sep 2025 13:15:33 -0400 Subject: [PATCH 07/14] =?UTF-8?q?fix:=20=F0=9F=98=AD?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- basis-library/schedulers/spork/Scheduler.sml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/basis-library/schedulers/spork/Scheduler.sml b/basis-library/schedulers/spork/Scheduler.sml index 9a59f1eb0..80de95c2f 100644 --- a/basis-library/schedulers/spork/Scheduler.sml +++ b/basis-library/schedulers/spork/Scheduler.sml @@ -116,7 +116,7 @@ struct fun __inline_always__ primSporkGive (body, spwn, seq, sync, exnseq, exnsync) = __inline_always__ primSporkGive' (body, (), spwn, (), seq, sync, exnseq, exnsync) (* TODO: Re-enable after implementing spork_choose *) - fun __inline_always__ primSporkChoose (loopBody, unrolled, regular) = regular + fun __inline_always__ primSporkChoose (loopBody, unrolled, regular) = unrolled val primForkThreadAndSetData = _prim "spork_forkThreadAndSetData": Thread.t * 'a -> Thread.p; val primForkThreadAndSetData_youngest = _prim "spork_forkThreadAndSetData_youngest": Thread.t * 'a -> Thread.p; From 0f39d01bdf13d1d0f8dda02b00de635bb6123c4f Mon Sep 17 00:00:00 2001 From: Sundara Vishnu Satish Date: Tue, 30 Sep 2025 19:46:31 -0400 Subject: [PATCH 08/14] fix: not call both versions, inlining --- basis-library/schedulers/spork/ForkJoin.sml | 39 ++++++++++----------- 1 file changed, 19 insertions(+), 20 deletions(-) diff --git 
a/basis-library/schedulers/spork/ForkJoin.sml b/basis-library/schedulers/spork/ForkJoin.sml index b7261b10c..e285708a7 100644 --- a/basis-library/schedulers/spork/ForkJoin.sml +++ b/basis-library/schedulers/spork/ForkJoin.sml @@ -296,12 +296,13 @@ struct local (* fallback to regular implementation until runtime supports spork_choose *) - fun primSporkChoose (loopBody, unrolled, regular) = unrolled + fun __inline_always__ primSporkChoose (loopBody, unrolled, regular) = __inline_always__ unrolled () fun unifiedReducem (combine: 'a * 'a -> 'a) (zero: 'a) (lo: int, hi: int) (f: int -> 'a) : 'a = let - fun regularImpl () = + fun __inline_always__ regularImpl () = let + (* TODO: look into this piece *) val pareduce = case Int.precision of SOME 8 => Loops8.pareduce | SOME 16 => Loops16.pareduce @@ -312,15 +313,15 @@ struct pareduce (lo, hi) zero (fn (i, a) => combine (a, f i)) combine end - fun unrolledImpl () = - Unrolled.pareduce (lo, hi) zero (fn (i, a) => combine (a, f i)) combine + fun __inline_always__ unrolledImpl () = + __inline_always__ Unrolled.pareduce (lo, hi) zero (fn (i, a) => combine (a, f i)) combine in - primSporkChoose (f, unrolledImpl (), regularImpl ()) + primSporkChoose (__inline_always__ f, __inline_always__ unrolledImpl, __inline_always__ regularImpl) end fun unifiedParform (lo: int, hi: int) (f: int -> unit) : unit = let - fun regularImpl () = + fun __inline_always__ regularImpl () = let val parform = case Int.precision of SOME 8 => Loops8.parform @@ -332,19 +333,16 @@ struct parform (lo, hi) f end - fun unrolledImpl () = - let - val _ = Unrolled.pareduce (lo, hi) () (fn (i, _) => f i) (fn _ => ()) - in - () - end + fun __inline_always__ unrolledImpl () = + __inline_always__ Unrolled.pareduce (lo, hi) () (fn (i, _) => f i) (fn _ => ()) + in - primSporkChoose (f, unrolledImpl (), regularImpl ()) + primSporkChoose (__inline_always__ f, __inline_always__ unrolledImpl, __inline_always__ regularImpl) end - fun unifiedPareduce (lo: int, hi: int) (zero: 
'a) (step: int * 'a -> 'a) (combine: 'a * 'a -> 'a) : 'a = + fun __inline_always__ unifiedPareduce (lo: int, hi: int) (zero: 'a) (step: int * 'a -> 'a) (combine: 'a * 'a -> 'a) : 'a = let - fun regularImpl () = + fun __inline_always__ regularImpl () = let val pareduce = case Int.precision of SOME 8 => Loops8.pareduce @@ -356,18 +354,19 @@ struct pareduce (lo, hi) zero step combine end - fun unrolledImpl () = - Unrolled.pareduce (lo, hi) zero step combine + fun __inline_always__ unrolledImpl () = + __inline_always__ Unrolled.pareduce (lo, hi) zero step combine - fun loopBody i = step (i, zero) + fun __inline_always__ loopBody i = step (i, zero) in - primSporkChoose (loopBody, unrolledImpl (), regularImpl ()) + primSporkChoose (__inline_always__ loopBody, __inline_always__ unrolledImpl, __inline_always__ regularImpl) end in val reducem = unifiedReducem val parform = unifiedParform val pareduce = unifiedPareduce + val parfor = ForkJoin0.parfor end val pareduceBreakExn = PareduceBreakExn.f -end \ No newline at end of file +end From c3c757fd6f34769df82419b82af9852cd0b90974 Mon Sep 17 00:00:00 2001 From: Sundara Vishnu Satish Date: Sun, 5 Oct 2025 23:11:38 -0400 Subject: [PATCH 09/14] =?UTF-8?q?chore:=20introduce=20new=20prim,=20compil?= =?UTF-8?q?es=20but=20bricked=20=F0=9F=98=94?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- basis-library/schedulers/spork/ForkJoin.sml | 44 ++++++++-- basis-library/schedulers/spork/Scheduler.sml | 11 +-- .../schedulers/spork/UnrolledLoops.sml | 87 ++++++++++--------- mlton/atoms/prim.fun | 50 +++++++---- mlton/atoms/prim.sig | 3 +- mlton/closure-convert/closure-convert.fun | 52 ++++++++++- 6 files changed, 176 insertions(+), 71 deletions(-) diff --git a/basis-library/schedulers/spork/ForkJoin.sml b/basis-library/schedulers/spork/ForkJoin.sml index e285708a7..766db6b1c 100644 --- a/basis-library/schedulers/spork/ForkJoin.sml +++ b/basis-library/schedulers/spork/ForkJoin.sml @@ -3,6 +3,7 @@ 
struct datatype TokenPolicy = datatype Scheduler.TokenPolicy val spork = Scheduler.SporkJoin.spork + val primSporkChoose = Scheduler.primSporkChoose fun par (f: unit -> 'a, g: unit -> 'b): 'a * 'b = spork { @@ -254,6 +255,12 @@ struct val equal = op= end) + (* TODO: Need to improve this interface *) + structure Unrolled8 = UnrolledLoops(Word8) + structure Unrolled16 = UnrolledLoops(Word16) + structure Unrolled32 = UnrolledLoops(Word32) + structure Unrolled64 = UnrolledLoops(Word64) + structure Pareduce = Int_ChooseFromInt (struct type 'a t = (int * int) -> 'a -> (int * 'a -> 'a) -> ('a * 'a -> 'a) -> 'a @@ -295,14 +302,10 @@ struct end) local - (* fallback to regular implementation until runtime supports spork_choose *) - fun __inline_always__ primSporkChoose (loopBody, unrolled, regular) = __inline_always__ unrolled () - fun unifiedReducem (combine: 'a * 'a -> 'a) (zero: 'a) (lo: int, hi: int) (f: int -> 'a) : 'a = let fun __inline_always__ regularImpl () = let - (* TODO: look into this piece *) val pareduce = case Int.precision of SOME 8 => Loops8.pareduce | SOME 16 => Loops16.pareduce @@ -314,7 +317,16 @@ struct end fun __inline_always__ unrolledImpl () = - __inline_always__ Unrolled.pareduce (lo, hi) zero (fn (i, a) => combine (a, f i)) combine + let + val pareduce = case Int.precision of + SOME 8 => Unrolled8.pareduce + | SOME 16 => Unrolled16.pareduce + | SOME 32 => Unrolled32.pareduce + | SOME 64 => Unrolled64.pareduce + | _ => Unrolled64.pareduce (* fallback to 64-bit for IntInf *) + in + __inline_always__ pareduce (lo, hi) zero (fn (i, a) => combine (a, f i)) combine + end in primSporkChoose (__inline_always__ f, __inline_always__ unrolledImpl, __inline_always__ regularImpl) end @@ -334,7 +346,16 @@ struct end fun __inline_always__ unrolledImpl () = - __inline_always__ Unrolled.pareduce (lo, hi) () (fn (i, _) => f i) (fn _ => ()) + let + val pareduce = case Int.precision of + SOME 8 => Unrolled8.pareduce + | SOME 16 => Unrolled16.pareduce + | SOME 32 => 
Unrolled32.pareduce + | SOME 64 => Unrolled64.pareduce + | _ => Unrolled64.pareduce (* fallback to 64-bit for IntInf *) + in + __inline_always__ pareduce (lo, hi) () (fn (i, _) => f i) (fn _ => ()) + end in primSporkChoose (__inline_always__ f, __inline_always__ unrolledImpl, __inline_always__ regularImpl) @@ -355,7 +376,16 @@ struct end fun __inline_always__ unrolledImpl () = - __inline_always__ Unrolled.pareduce (lo, hi) zero step combine + let + val pareduce = case Int.precision of + SOME 8 => Unrolled8.pareduce + | SOME 16 => Unrolled16.pareduce + | SOME 32 => Unrolled32.pareduce + | SOME 64 => Unrolled64.pareduce + | _ => Unrolled64.pareduce (* fallback to 64-bit for IntInf *) + in + __inline_always__ pareduce (lo, hi) zero step combine + end fun __inline_always__ loopBody i = step (i, zero) in diff --git a/basis-library/schedulers/spork/Scheduler.sml b/basis-library/schedulers/spork/Scheduler.sml index 80de95c2f..0fa517cdb 100644 --- a/basis-library/schedulers/spork/Scheduler.sml +++ b/basis-library/schedulers/spork/Scheduler.sml @@ -103,11 +103,12 @@ struct * (exn -> 'c) (* exn seq *) * (exn * 'd -> 'c) (* exn sync *) -> 'c; - (* TODO: commented out until runtime implements spork_choose val primSporkChoose' = _prim "spork_choose" - : ('u -> 'v) -> 'a -> 'a -> 'a - *) + : ('u -> 'a) (* loop body *) + * (unit -> 'a) (* unrolled implementation *) + * (unit -> 'a) (* regular implementation *) + -> 'a; fun __inline_always__ primSporkFair (body, spwn, seq, sync, exnseq, exnsync) = __inline_always__ primSporkFair' (body, (), spwn, (), seq, sync, exnseq, exnsync) @@ -115,8 +116,8 @@ struct __inline_always__ primSporkKeep' (body, (), spwn, (), seq, sync, exnseq, exnsync) fun __inline_always__ primSporkGive (body, spwn, seq, sync, exnseq, exnsync) = __inline_always__ primSporkGive' (body, (), spwn, (), seq, sync, exnseq, exnsync) - (* TODO: Re-enable after implementing spork_choose *) - fun __inline_always__ primSporkChoose (loopBody, unrolled, regular) = unrolled + 
fun __inline_always__ primSporkChoose (loopBody, unrolled, regular) = + __inline_always__ primSporkChoose' (loopBody, unrolled, regular) val primForkThreadAndSetData = _prim "spork_forkThreadAndSetData": Thread.t * 'a -> Thread.p; val primForkThreadAndSetData_youngest = _prim "spork_forkThreadAndSetData_youngest": Thread.t * 'a -> Thread.p; diff --git a/basis-library/schedulers/spork/UnrolledLoops.sml b/basis-library/schedulers/spork/UnrolledLoops.sml index 3b04c8e20..fa5944fa4 100644 --- a/basis-library/schedulers/spork/UnrolledLoops.sml +++ b/basis-library/schedulers/spork/UnrolledLoops.sml @@ -1,12 +1,21 @@ -structure Unrolled = +functor UnrolledLoops(WordImpl: WORD) = struct - type word = Word64.word - fun __inline_always__ w2i w = __inline_always__ Word64.toIntX w - fun __inline_always__ i2w i = __inline_always__ Word64.fromInt i + type word = WordImpl.word + fun __inline_always__ w2i w = __inline_always__ WordImpl.toIntX w + fun __inline_always__ i2w i = __inline_always__ WordImpl.fromInt i + + val one = __inline_always__ WordImpl.fromInt 1 + val two = __inline_always__ WordImpl.fromInt 2 + val three = __inline_always__ WordImpl.fromInt 3 + val four = __inline_always__ WordImpl.fromInt 4 + val five = __inline_always__ WordImpl.fromInt 5 + val six = __inline_always__ WordImpl.fromInt 6 + val seven = __inline_always__ WordImpl.fromInt 7 + val eight = __inline_always__ WordImpl.fromInt 8 fun __inline_always__ midpoint (i: word, j: word) = - i + (Word64.>> (j - i, 0w1)) + WordImpl.+ (i, WordImpl.>> (WordImpl.- (j, i), 0w1)) fun __inline_always__ pareduce (lo, hi) (z: 'a) (step': int * 'a -> 'a) (g: 'a * 'a -> 'a) : 'a = @@ -26,18 +35,18 @@ struct (* fun __inline_always__ next stride = Word64.min (Word64.<< (stride, 0w1), 0w16) *) - + fun loop8 (a, i, j) = - if i + 0w8 <= j then + if WordImpl.<= (WordImpl.+ (i, eight), j) then let fun __inline_never__ spwn a' = - if i + 0w8 >= j then a' else + if WordImpl.>= (WordImpl.+ (i, eight), j) then a' else let - val mid = 
midpoint (i + 0w8, j) + val mid = midpoint (WordImpl.+ (i, eight), j) in Scheduler.SporkJoin.spork { tokenPolicy = Scheduler.TokenPolicyFair, - body = fn () => loop1 (a', i + 0w8, mid), + body = fn () => loop1 (a', WordImpl.+ (i, eight), mid), spwn = fn () => loop1 (z, mid, j), seq = fn a'' => loop1 (a'', mid, j), sync = g, @@ -50,17 +59,17 @@ struct body = fn () => let val a = step (a, i) - val a = step (a, i+0w1) - val a = step (a, i+0w2) - val a = step (a, i+0w3) - val a = step (a, i+0w4) - val a = step (a, i+0w5) - val a = step (a, i+0w6) - val a = step (a, i+0w7) + val a = step (a, WordImpl.+ (i, one)) + val a = step (a, WordImpl.+ (i, two)) + val a = step (a, WordImpl.+ (i, three)) + val a = step (a, WordImpl.+ (i, four)) + val a = step (a, WordImpl.+ (i, five)) + val a = step (a, WordImpl.+ (i, six)) + val a = step (a, WordImpl.+ (i, seven)) in a end, - seq = fn a' => loop8 (a', i + 0w8, j), + seq = fn a' => loop8 (a', WordImpl.+ (i, eight), j), sync = g, spwn = fn () => spwn z, unstolen = SOME spwn @@ -72,16 +81,16 @@ struct and loop4 (a, i, j) = - if i + 0w4 <= j then + if WordImpl.<= (WordImpl.+ (i, four), j) then let fun __inline_never__ spwn a' = - if i + 0w4 >= j then a' else + if WordImpl.>= (WordImpl.+ (i, four), j) then a' else let - val mid = midpoint (i + 0w4, j) + val mid = midpoint (WordImpl.+ (i, four), j) in Scheduler.SporkJoin.spork { tokenPolicy = Scheduler.TokenPolicyFair, - body = fn () => loop1 (a', i + 0w4, mid), + body = fn () => loop1 (a', WordImpl.+ (i, four), mid), spwn = fn () => loop1 (z, mid, j), seq = fn a'' => loop1 (a'', mid, j), sync = g, @@ -94,13 +103,13 @@ struct body = fn () => let val a = step (a, i) - val a = step (a, i+0w1) - val a = step (a, i+0w2) - val a = step (a, i+0w3) + val a = step (a, WordImpl.+ (i, one)) + val a = step (a, WordImpl.+ (i, two)) + val a = step (a, WordImpl.+ (i, three)) in a end, - seq = fn a' => loop8 (a', i + 0w4, j), + seq = fn a' => loop8 (a', WordImpl.+ (i, four), j), sync = g, spwn = fn () 
=> spwn z, unstolen = SOME spwn @@ -110,18 +119,18 @@ struct loop1 (a, i, j) - + and loop2 (a, i, j) = - if i + 0w2 <= j then + if WordImpl.<= (WordImpl.+ (i, two), j) then let fun __inline_never__ spwn a' = - if i + 0w2 >= j then a' else + if WordImpl.>= (WordImpl.+ (i, two), j) then a' else let - val mid = midpoint (i + 0w2, j) + val mid = midpoint (WordImpl.+ (i, two), j) in Scheduler.SporkJoin.spork { tokenPolicy = Scheduler.TokenPolicyFair, - body = fn () => loop1 (a', i + 0w2, mid), + body = fn () => loop1 (a', WordImpl.+ (i, two), mid), spwn = fn () => loop1 (z, mid, j), seq = fn a'' => loop1 (a'', mid, j), sync = g, @@ -134,11 +143,11 @@ struct body = fn () => let val a = step (a, i) - val a = step (a, i+0w1) + val a = step (a, WordImpl.+ (i, one)) in a end, - seq = fn a' => loop4 (a', i + 0w2, j), + seq = fn a' => loop4 (a', WordImpl.+ (i, two), j), sync = g, spwn = fn () => spwn z, unstolen = SOME spwn @@ -148,18 +157,18 @@ struct loop1 (a, i, j) - + and loop1 (a, i, j) = - if i + 0w1 <= j then + if WordImpl.<= (WordImpl.+ (i, one), j) then let fun __inline_never__ spwn a' = - if i + 0w1 >= j then a' else + if WordImpl.>= (WordImpl.+ (i, one), j) then a' else let - val mid = midpoint (i + 0w1, j) + val mid = midpoint (WordImpl.+ (i, one), j) in Scheduler.SporkJoin.spork { tokenPolicy = Scheduler.TokenPolicyFair, - body = fn () => loop1 (a', i + 0w1, mid), + body = fn () => loop1 (a', WordImpl.+ (i, one), mid), spwn = fn () => loop1 (z, mid, j), seq = fn a'' => loop1 (a'', mid, j), sync = g, @@ -170,7 +179,7 @@ struct Scheduler.SporkJoin.spork { tokenPolicy = Scheduler.TokenPolicyGive, body = fn () => step (a, i), - seq = fn a' => loop2 (a', i + 0w1, j), + seq = fn a' => loop2 (a', WordImpl.+ (i, one), j), sync = g, spwn = fn () => spwn z, unstolen = SOME spwn diff --git a/mlton/atoms/prim.fun b/mlton/atoms/prim.fun index 7f49aab13..6cd58d496 100644 --- a/mlton/atoms/prim.fun +++ b/mlton/atoms/prim.fun @@ -113,8 +113,8 @@ datatype 'a t = | MLton_share (* 
to rssa (as nop or runtime C fn) *) | MLton_size (* to rssa (as runtime C fn) *) | MLton_touch (* to rssa (as nop) or backend (as nop) *) - (* TODO: choose between unrolled and regular *) - | Spork_choose of {tokenSplitPolicy: Word32.word} (* closure convert *) + (* Choose between unrolled and regular at compile time *) + | Spork_choose (* closure convert *) | Spork of {tokenSplitPolicy: Word32.word} (* closure convert *) | Spork_forkThreadAndSetData of {youngest: bool} (* to rssa (as runtime C fn) *) | Spork_getData of Spid.t (* backend *) @@ -296,11 +296,7 @@ fun toString (n: 'a t): string = | MLton_share => "MLton_share" | MLton_size => "MLton_size" | MLton_touch => "MLton_touch" - (* TODO: choose spork *) - | Spork_choose {tokenSplitPolicy=0w0} => "spork_choose_fair" - | Spork_choose {tokenSplitPolicy=0w1} => "spork_choose_keep" - | Spork_choose {tokenSplitPolicy=0w2} => "spork_choose_give" - | Spork_choose {tokenSplitPolicy=pol} => Error.bug ("unknown spork tokensplitpolicy " ^ Word32.toString pol) + | Spork_choose => "spork_choose" | Spork {tokenSplitPolicy=0w0} => "spork_fair" | Spork {tokenSplitPolicy=0w1} => "spork_keep" | Spork {tokenSplitPolicy=0w2} => "spork_give" @@ -469,7 +465,8 @@ val equals: 'a t * 'a t -> bool = | (MLton_share, MLton_share) => true | (MLton_size, MLton_size) => true | (MLton_touch, MLton_touch) => true - (* TODO: isLoop1 = isLoop2 *) + (* TODO: Check usage properly *) + | (Spork_choose, Spork_choose) => true | (Spork {tokenSplitPolicy = tsp1}, Spork {tokenSplitPolicy = tsp2}) => tsp1 = tsp2 | (Spork_forkThreadAndSetData yo1, Spork_forkThreadAndSetData yo2) => yo1 = yo2 | (Spork_getData spid, Spork_getData spid') => Spid.equals (spid, spid') @@ -655,8 +652,9 @@ val map: 'a t * ('a -> 'b) -> 'b t = | MLton_share => MLton_share | MLton_size => MLton_size | MLton_touch => MLton_touch - (* TODO: spork tsp isLoop *) | Spork tsp => Spork tsp + (* TODO: Check usage properly *) + | Spork_choose => Spork_choose | Spork_forkThreadAndSetData z 
=> Spork_forkThreadAndSetData z | Spork_getData spid => Spork_getData spid | Real_Math_acos z => Real_Math_acos z @@ -872,8 +870,9 @@ val kind: 'a t -> Kind.t = | MLton_share => SideEffect | MLton_size => DependsOnState | MLton_touch => SideEffect - (* TODO *) | Spork _ => SideEffect + (* TODO: Check usage properly *) + | Spork_choose => Functional | Spork_forkThreadAndSetData _ => SideEffect | Spork_getData _ => DependsOnState | Real_Math_acos _ => DependsOnState (* depends on rounding mode *) @@ -1084,13 +1083,11 @@ in MLton_share, MLton_size, MLton_touch, - (* TODO: Spork {tokenSplitPolicy = ..., isLoop = ...} *) Spork {tokenSplitPolicy = 0w0}, Spork {tokenSplitPolicy = 0w1}, Spork {tokenSplitPolicy = 0w2}, - Spork_choose {tokenSplitPolicy = 0w0}, - Spork_choose {tokenSplitPolicy = 0w1}, - Spork_choose {tokenSplitPolicy = 0w2}, + (* TODO: Check usage properly *) + Spork_choose, Spork_forkThreadAndSetData {youngest=true}, Spork_forkThreadAndSetData {youngest=false}, (*Spork_getData,*) @@ -1293,6 +1290,9 @@ fun 'a checkApp (prim: 'a t, fun oneTarg f = 1 = Vector.length targs andalso done (f (targ 0)) + fun twoTargs f = + 2 = Vector.length targs + andalso done (f (targ 0, targ 1)) fun sixTargs f = 6 = Vector.length targs andalso done (f (targ 0, targ 1, targ 2, targ 3, targ 4, targ 5)) @@ -1442,7 +1442,6 @@ fun 'a checkApp (prim: 'a t, | MLton_size => oneTarg (fn t => (oneArg t, csize)) | MLton_touch => oneTarg (fn t => (oneArg t, unit)) | Spork _ => - (* TODO: Add isLoop arg -> sevenTargs and nineTargs? 
*) (* spork: ('aa -> 'ar) * 'aa * ('ba * 'd -> 'br) * 'ba * ('ar -> 'c) * ('ar * 'd -> 'c) * (exn -> 'c) * (exn * 'd -> 'c) -> 'c *) (* ('aa -> 'ar) * 'aa * ('ba * 'd -> 'br) * 'ba * ('ar -> 'c) * ('ar * 'd -> 'c) * (exn -> 'c) * (exn * 'd -> 'c) -> 'c *) sixTargs (fn (taa, tar, tba, tbr, td, tc) => @@ -1456,6 +1455,17 @@ fun 'a checkApp (prim: 'a t, in (eightArgs (cont, taa, spwn, tba, seq, sync, exnseq, exnsync), tc) end) + | Spork_choose => + (* TODO: Check usage properly *) + (* spork_choose: ('u -> 'v) -> (unit -> 'a) -> (unit -> 'a) -> 'a + * where 'v = 'a, so really: ('u -> 'a) -> (unit -> 'a) -> (unit -> 'a) -> 'a *) + twoTargs (fn (ta, tu) => + let + val loopBody = arrow (tu, ta) (* First arg: loop body function 'u -> 'a *) + val impl = arrow (unit, ta) (* Second and third args: thunks unit -> 'a *) + in + (threeArgs (loopBody, impl, impl), ta) + end) | Spork_forkThreadAndSetData _ => oneTarg (fn t => (twoArgs (thread, t), thread)) | Spork_getData _ => oneTarg (fn t => (noArgs, t)) | Real_Math_acos s => realUnary s @@ -1600,7 +1610,6 @@ fun ('a, 'b) extractTargs (prim: 'b t, | MLton_size => one (arg 0) | MLton_touch => one (arg 0) | Spork _ => - (* TODO: add isLoop *) (* spork: ('aa -> 'ar) * 'aa * ('ba * 'd -> 'br) * 'ba * ('ar -> 'c) * ('ar * 'd -> 'c) * (exn -> 'c) * (exn * 'd -> 'c) -> 'c *) let val (taa, tar) = deArrow (arg 0) @@ -1612,6 +1621,15 @@ fun ('a, 'b) extractTargs (prim: 'b t, in six (taa, tar, tba, tbr, td, tc) end + | Spork_choose => + (* TODO: Check usage properly *) + (* spork_choose: ('u -> 'v) -> 'a -> 'a -> 'a *) + let + val ta = result (* Result type 'a *) + val (tu, _) = deArrow (arg 0) (* First arg: loop body ('u -> 'v) *) + in + Vector.new2 (ta, tu) + end | Spork_forkThreadAndSetData _ => one (arg 1) | Spork_getData _ => one result | Ref_assign _ => one (deRef (arg 0)) diff --git a/mlton/atoms/prim.sig b/mlton/atoms/prim.sig index b58b457be..9188455d4 100644 --- a/mlton/atoms/prim.sig +++ b/mlton/atoms/prim.sig @@ -104,8 +104,7 
@@ signature PRIM = | MLton_share (* to rssa (as nop or runtime C fn) *) | MLton_size (* to rssa (as runtime C fn) *) | MLton_touch (* to rssa (as nop) or backend (as nop) *) - (* TODO: Check usage properly *) - | Spork_choose of {tokenSplitPolicy: Word32.word} (* closure convert *) + | Spork_choose (* TODO: closure convert / SSA / SSA2 / RSSA ? *) | Spork of {tokenSplitPolicy: Word32.word} (* closure convert *) | Spork_forkThreadAndSetData of {youngest: bool} (* to rssa (as runtime C fn) *) | Spork_getData of Spid.t (* backend *) diff --git a/mlton/closure-convert/closure-convert.fun b/mlton/closure-convert/closure-convert.fun index f3afa3546..a74befab9 100644 --- a/mlton/closure-convert/closure-convert.fun +++ b/mlton/closure-convert/closure-convert.fun @@ -158,7 +158,6 @@ val convertPrimExpInfo = Trace.info "ClosureConvert.convertPrimExp" val valueTypeInfo = Trace.info "ClosureConvert.valueType" structure LambdaFree = LambdaFree (Sxml) -(* # NOTE: Free variable analysis *) local open LambdaFree @@ -477,6 +476,13 @@ fun closureConvert in () end + (* ! THIS IS OBVIOUSLY BRICKED *) + (* TODO: Check usage properly *) + | PrimApp {prim = Prim.Spork_choose, targs, args} => + (* spork_choose: ('u -> 'v) -> (unit -> 'a) -> (unit -> 'a) -> 'a + * Don't try to apply primApply with function arguments; just create + * a fresh abstract value of the result type *) + new' () | PrimApp {prim, args, ...} => set (Value.primApply {prim = prim, args = varExps args, @@ -1102,7 +1108,6 @@ fun closureConvert args = Vector.new1 (lambdaInfoTuple info)}, ac) end - (* TODO: isLoop bool *) | SprimExp.PrimApp {prim = Prim.Spork {tokenSplitPolicy}, targs, args} => (* spork: ('aa -> 'ar) * 'aa * ('ba * 'd -> 'br) * 'bb * ('ar -> 'c) * ('ar * 'd -> 'c) -> 'c *) let @@ -1257,6 +1262,49 @@ fun closureConvert in (exp, ac) end + (* TODO: Check usage properly *) + (* ! 
THIS IS BRICKED *) + | SprimExp.PrimApp {prim = Prim.Spork_choose, targs, args} => + (* spork_choose: ('u -> 'v) -> (unit -> 'a) -> (unit -> 'a) -> 'a + * For now, apply the regular implementation (third arg) to unit *) + let + fun arg i = Vector.sub (args, i) + val regular = arg 2 + val func = varExpInfo regular + val funcVal = VarInfo.value func + val unitExp = Dexp.tuple {exps = Vector.new0 (), + ty = Type.tuple (Vector.new0 ())} + val unitVal = Value.tuple (Vector.new0 ()) + val ty_result = valueType v + val {cons, ...} = valueLambdasInfo funcVal + in + (* Generate application of regular to unit, similar to apply function *) + (Dexp.casee + {test = convertVarInfo func, + ty = ty_result, + default = NONE, + cases = + Dexp.Con + (Vector.map + (cons, fn {lambda, con} => + let + val {arg = param, body, ...} = Slambda.dest lambda + val info as LambdaInfo.T {name, ...} = lambdaInfo lambda + val result = expValue body + val env = (Var.newString "env", lambdaInfoType info) + in {con = con, + args = Vector.new1 env, + body = coerce (Dexp.call + {func = name, + args = Vector.new2 (Dexp.var env, + coerce (unitExp, unitVal, + value param)), + inline = InlineAttr.Auto, + ty = valueType result}, + result, v)} + end))}, + ac) + end | SprimExp.PrimApp {prim, targs, args} => let val prim = Prim.map (prim, convertType) From f8987e8f4589629757df952a73f0d05fad9e64b4 Mon Sep 17 00:00:00 2001 From: Sundara Vishnu Satish Date: Tue, 28 Oct 2025 14:24:24 -0400 Subject: [PATCH 10/14] feat: please don't disappear again --- mlton/closure-convert/closure-convert.fun | 197 +++++++++++++++++++--- 1 file changed, 177 insertions(+), 20 deletions(-) diff --git a/mlton/closure-convert/closure-convert.fun b/mlton/closure-convert/closure-convert.fun index a74befab9..3167545b9 100644 --- a/mlton/closure-convert/closure-convert.fun +++ b/mlton/closure-convert/closure-convert.fun @@ -224,6 +224,97 @@ val traceLoopBind = ("exp", SprimExp.layout exp)], Unit.layout) +(* similar to lambdaSize in 
polyvariance *) +fun sxmlLambdaSize (l: Slambda.t): int = + let + fun loopExp (e: Sexp.t, n: int): int = + List.fold + (Sexp.decs e, n, fn (d, n) => + case d of + Sdec.MonoVal {exp, ...} => loopPrimExp (exp, n + 1) + | Sdec.PolyVal {exp, ...} => loopExp (exp, n + 1) + | Sdec.Fun {decs, ...} => Vector.fold (decs, n, fn ({lambda, ...}, n) => + loopLambda (lambda, n)) + | Sdec.Exception _ => n + 1) + and loopLambda (l: Slambda.t, n): int = + let val m = loopExp (Slambda.body l, 0) + in m + n + end + and loopPrimExp (e: SprimExp.t, n: int): int = + case e of + SprimExp.Case {cases, default, ...} => + let + val n = n + 1 + in + Scases.fold + (cases, + (case default of + NONE => n + | SOME e => loopExp (e, n)), + loopExp) + end + | SprimExp.Handle {try, handler, ...} => + loopExp (try, loopExp (handler, n + 1)) + | SprimExp.Lambda l => loopLambda (l, n + 1) + | SprimExp.Profile _ => n + | _ => n + 1 + in + loopExp (Slambda.body l, 0) + end + +fun analyzeLambdaBody (l: Slambda.t): unit = + let + val counts = ref [] + fun addCount (name: string) = List.push (counts, name) + + fun loopExp (e: Sexp.t): unit = + List.foreach + (Sexp.decs e, fn d => + case d of + Sdec.MonoVal {exp, var, ...} => + (addCount (concat ["MonoVal(", Var.toString var, ")"]) + ; loopPrimExp exp) + | Sdec.PolyVal {exp, ...} => + (addCount "PolyVal" + ; loopExp exp) + | Sdec.Fun {decs, ...} => + (addCount (concat ["Fun(", Int.toString (Vector.length decs), " lambdas)"]) + ; Vector.foreach (decs, fn {lambda, ...} => loopLambda lambda)) + | Sdec.Exception _ => addCount "Exception") + and loopLambda (l: Slambda.t): unit = + (addCount "Lambda" + ; loopExp (Slambda.body l)) + and loopPrimExp (e: SprimExp.t): unit = + case e of + SprimExp.App {func, ...} => + addCount (concat ["App(", Layout.toString (SvarExp.layout func), ")"]) + | SprimExp.Case _ => (addCount "Case"; (* would need to recurse into cases *) ()) + | SprimExp.ConApp {con, ...} => + addCount (concat ["ConApp(", Con.toString con, ")"]) + | 
SprimExp.Const c => addCount (concat ["Const(", Const.toString c, ")"]) + | SprimExp.Handle _ => (addCount "Handle"; (* would need to recurse *) ()) + | SprimExp.Lambda l => loopLambda l + | SprimExp.PrimApp {prim, ...} => + addCount (concat ["PrimApp(", Prim.toString prim, ")"]) + | SprimExp.Profile _ => () + | SprimExp.Raise _ => addCount "Raise" + | SprimExp.Select {offset, ...} => + addCount (concat ["Select(#", Int.toString offset, ")"]) + | SprimExp.Tuple xs => + addCount (concat ["Tuple(", Int.toString (Vector.length xs), ")"]) + | SprimExp.Var x => addCount (concat ["Var(", Layout.toString (SvarExp.layout x), ")"]) + val _ = loopExp (Slambda.body l) + val _ = Control.diagnostics + (fn display => + List.foreach + (rev (!counts), fn name => + display (let open Layout + in seq [str " - ", str name] + end))) + in + () + end + fun closureConvert (program as Sxml.Program.T {datatypes, body}): Ssa.Program.t = let @@ -476,13 +567,28 @@ fun closureConvert in () end - (* ! THIS IS OBVIOUSLY BRICKED *) - (* TODO: Check usage properly *) - | PrimApp {prim = Prim.Spork_choose, targs, args} => - (* spork_choose: ('u -> 'v) -> (unit -> 'a) -> (unit -> 'a) -> 'a - * Don't try to apply primApply with function arguments; just create - * a fresh abstract value of the result type *) - new' () + | PrimApp {prim = Prim.Spork_choose, targs, args} => + (* spork_choose: ('u -> 'a) -> (unit -> 'a) -> (unit -> 'a) -> 'a + * model as applying regular to unit *) + let + fun arg i = Vector.sub (args, i) + val regular = arg 2 + val unitTy = Stype.tuple (Vector.new0 ()) + val unitArg = Value.fromType unitTy + val result = new () + val _ = Value.addHandler + (varExp regular, fn l => + let + val lambda = Value.Lambda.dest l + val {arg = formal, body, ...} = Lambda.dest lambda + in + Value.coerce {from = unitArg, to = value formal} + ; Value.coerce {from = expValue body, to = result} + (* ! 
body is not really needed for result here *) + end) + in + () + end | PrimApp {prim, args, ...} => set (Value.primApply {prim = prim, args = varExps args, @@ -819,7 +925,7 @@ fun closureConvert cases = Dexp.Con (Vector.map - (cons, fn {lambda, con} => + (cons, fn {lambda, con} => let val {arg = param, body, ...} = Slambda.dest lambda val info as LambdaInfo.T {name, ...} = lambdaInfo lambda @@ -1262,26 +1368,77 @@ fun closureConvert in (exp, ac) end - (* TODO: Check usage properly *) - (* ! THIS IS BRICKED *) | SprimExp.PrimApp {prim = Prim.Spork_choose, targs, args} => - (* spork_choose: ('u -> 'v) -> (unit -> 'a) -> (unit -> 'a) -> 'a - * For now, apply the regular implementation (third arg) to unit *) + (* spork_choose: ('u -> 'a) -> (unit -> 'a) -> (unit -> 'a) -> 'a *) let fun arg i = Vector.sub (args, i) + val loopBodyVar = arg 0 + val unrolled = arg 1 val regular = arg 2 - val func = varExpInfo regular + + val loopBodyInfo = varExpInfo loopBodyVar + val loopBodySize = + let + val v = #value loopBodyInfo + fun extractLambda (v: Value.t): Slambda.t option = + case Value.dest v of + Value.Lambdas ls => + (case Lambdas.toList ls of + l :: _ => SOME (Value.Lambda.dest l) + | [] => NONE) + | _ => NONE + in + case extractLambda v of + SOME l => + let + val size = sxmlLambdaSize l + val _ = Control.diagnostics + (fn display => + display (let open Layout + in seq [str "spork_choose loop body size: ", + Int.layout size] + end)) + val _ = analyzeLambdaBody l + in + size + end + | NONE => + let + val _ = Control.diagnostics + (fn display => + display (let open Layout + in seq [str "spork_choose: no lambda found for ", + SvarExp.layout loopBodyVar] + end)) + in + 0 + end + end + + (* choose between unrolled and regular based on loop body size *) + val threshold = 100 + val chosenImpl = if loopBodySize <= threshold then unrolled else regular + val _ = Control.diagnostics + (fn display => + display (let open Layout + in seq [str " Decision: using ", + str (if loopBodySize 
<= threshold then "UNROLLED" else "REGULAR"), + str " (threshold=", Int.layout threshold, str ")"] + end)) + + (* apply the chosen implementation to unit *) + val func = varExpInfo chosenImpl val funcVal = VarInfo.value func - val unitExp = Dexp.tuple {exps = Vector.new0 (), - ty = Type.tuple (Vector.new0 ())} - val unitVal = Value.tuple (Vector.new0 ()) - val ty_result = valueType v + (* unit value *) + val unitTy = Type.tuple (Vector.new0 ()) + val unitExp = Dexp.tuple {exps = Vector.new0 (), ty = unitTy} + val unitVal = Value.fromType (Stype.tuple (Vector.new0 ())) val {cons, ...} = valueLambdasInfo funcVal in - (* Generate application of regular to unit, similar to apply function *) + (* ! heavy code duplication from apply *) (Dexp.casee {test = convertVarInfo func, - ty = ty_result, + ty = ty, default = NONE, cases = Dexp.Con @@ -1541,4 +1698,4 @@ fun closureConvert program end -end +end \ No newline at end of file From b6e3a8c52786f704aa122bf7608fc60db5abf254 Mon Sep 17 00:00:00 2001 From: Sundara Vishnu Satish Date: Wed, 29 Oct 2025 19:31:47 -0400 Subject: [PATCH 11/14] fix: critical blunder in spork_choose prim type --- mlton/atoms/prim.fun | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mlton/atoms/prim.fun b/mlton/atoms/prim.fun index 6cd58d496..35ecdcaec 100644 --- a/mlton/atoms/prim.fun +++ b/mlton/atoms/prim.fun @@ -872,7 +872,7 @@ val kind: 'a t -> Kind.t = | MLton_touch => SideEffect | Spork _ => SideEffect (* TODO: Check usage properly *) - | Spork_choose => Functional + | Spork_choose => SideEffect | Spork_forkThreadAndSetData _ => SideEffect | Spork_getData _ => DependsOnState | Real_Math_acos _ => DependsOnState (* depends on rounding mode *) From 693ef459d5079b91140457d4bd746d2d9913f3f7 Mon Sep 17 00:00:00 2001 From: Sundara Vishnu Satish Date: Wed, 5 Nov 2025 14:07:13 -0500 Subject: [PATCH 12/14] feat: static precision check, interface cleanup - now, the integer precision check should be done statically using the WordImpl and 
ManagedLoops using `Int_ChooseFromInt` - abstracted out `LoopIndex` part into different file and added it to sources --- basis-library/schedulers/spork/ForkJoin.sml | 212 ++++++------------ basis-library/schedulers/spork/LoopIndex.sml | 58 +++++ .../schedulers/spork/UnrolledLoops.sml | 76 ++++++- basis-library/schedulers/spork/sources.mlb | 1 + mlton/closure-convert/closure-convert.fun | 6 +- 5 files changed, 199 insertions(+), 154 deletions(-) create mode 100644 basis-library/schedulers/spork/LoopIndex.sml diff --git a/basis-library/schedulers/spork/ForkJoin.sml b/basis-library/schedulers/spork/ForkJoin.sml index 766db6b1c..8f7181047 100644 --- a/basis-library/schedulers/spork/ForkJoin.sml +++ b/basis-library/schedulers/spork/ForkJoin.sml @@ -55,20 +55,6 @@ struct end -signature LOOP_INDEX = -sig - type idx - type t = idx - - val fromInt: int -> idx - val toInt: idx -> int - - val increment: idx -> idx - val midpoint: idx * idx -> idx - val equal: idx * idx -> bool -end - - functor ManagedLoops (LoopIndex: LOOP_INDEX) :> sig val pareduce: (int * int) -> 'a -> (int * 'a -> 'a) -> ('a * 'a -> 'a) -> 'a @@ -164,52 +150,6 @@ struct end -functor LoopIndexFromWord(WordImpl: WORD) :> LOOP_INDEX = -struct - type idx = WordImpl.word - type t = idx - - fun __inline_always__ toInt (w: idx) = __inline_always__ WordImpl.toIntX w - fun __inline_always__ fromInt i = __inline_always__ WordImpl.fromInt i - - fun __inline_always__ midpoint (i: idx, j: idx) = - let - (* This way is broken! *) - (* val mid = WordImpl.~>> (WordImpl.+ (i, j), 0w1) *) - - val range_size = WordImpl.+ (j, WordImpl.~ i) - val mid = WordImpl.+ (i, WordImpl.div (range_size, WordImpl.fromInt 2)) - in - (* If using a different midpoint calculation, consider uncommenting - * the following for debugging/testing. - *) - - (* if toInt i <= toInt mid andalso toInt mid <= toInt j then - () - else - ( print - ( "ERROR: schedulers/spork/ForkJoin.sml: bug! 
midpoint failure: " - ^ Int.toString (toInt i) - ^ " " - ^ Int.toString (toInt mid) - ^ " " - ^ Int.toString (toInt j) - ^ "\n" - ) - - ; OS.Process.exit OS.Process.failure - ); *) - - mid - end - - fun __inline_always__ increment (i: idx) = - WordImpl.+ (i, fromInt 1) - - fun __inline_always__ equal (i: idx, j: idx) = (i = j) -end - - structure ForkJoin :> sig datatype TokenPolicy = datatype Scheduler.TokenPolicy @@ -222,7 +162,10 @@ sig val pareduceBreakExn: (int * int) -> 'a -> (('a -> exn) * int * 'a -> 'a) -> ('a * 'a -> 'a) -> 'a val reducem: ('a * 'a -> 'a) -> 'a -> (int * int) -> (int -> 'a) -> 'a + val reduce: ('a * 'a -> 'a) -> 'a -> (int * int) -> (int -> 'a) -> 'a + val reducemDefault: ('a * 'a -> 'a) -> 'a -> (int * int) -> (int -> 'a) -> 'a val parform: (int * int) -> (int -> unit) -> unit + val parformDefault: (int * int) -> (int -> unit) -> unit val parfor: int -> (int * int) -> (int -> unit) -> unit val alloc: int -> 'a array @@ -255,7 +198,6 @@ struct val equal = op= end) - (* TODO: Need to improve this interface *) structure Unrolled8 = UnrolledLoops(Word8) structure Unrolled16 = UnrolledLoops(Word16) structure Unrolled32 = UnrolledLoops(Word32) @@ -301,102 +243,86 @@ struct val fIntInf = LoopsInt.parform end) + structure UnrolledPareduce = + Int_ChooseFromInt (struct + type 'a t = (int * int) -> 'a -> (int * 'a -> 'a) -> ('a * 'a -> 'a) -> 'a + val fInt8 = Unrolled8.pareduce + val fInt16 = Unrolled16.pareduce + val fInt32 = Unrolled32.pareduce + val fInt64 = Unrolled64.pareduce + val fIntInf = Unrolled64.pareduce + end) + + structure UnrolledPareduceBreakExn = + Int_ChooseFromInt (struct + type 'a t = (int * int) -> 'a -> (('a -> exn) * int * 'a -> 'a) -> ('a * 'a -> 'a) -> 'a + val fInt8 = Unrolled8.pareduceBreakExn + val fInt16 = Unrolled16.pareduceBreakExn + val fInt32 = Unrolled32.pareduceBreakExn + val fInt64 = Unrolled64.pareduceBreakExn + val fIntInf = Unrolled64.pareduceBreakExn + end) + + structure UnrolledReducem = + Int_ChooseFromInt 
(struct + type 'a t = ('a * 'a -> 'a) -> 'a -> (int * int) -> (int -> 'a) -> 'a + val fInt8 = Unrolled8.reducem + val fInt16 = Unrolled16.reducem + val fInt32 = Unrolled32.reducem + val fInt64 = Unrolled64.reducem + val fIntInf = Unrolled64.reducem + end) + + structure UnrolledParform = + Int_ChooseFromInt (struct + type 'a t = (int * int) -> (int -> unit) -> unit + val fInt8 = Unrolled8.parform + val fInt16 = Unrolled16.parform + val fInt32 = Unrolled32.parform + val fInt64 = Unrolled64.parform + val fIntInf = Unrolled64.parform + end) + local - fun unifiedReducem (combine: 'a * 'a -> 'a) (zero: 'a) (lo: int, hi: int) (f: int -> 'a) : 'a = - let - fun __inline_always__ regularImpl () = - let - val pareduce = case Int.precision of - SOME 8 => Loops8.pareduce - | SOME 16 => Loops16.pareduce - | SOME 32 => Loops32.pareduce - | SOME 64 => Loops64.pareduce - | _ => LoopsInt.pareduce - in - pareduce (lo, hi) zero (fn (i, a) => combine (a, f i)) combine - end - fun __inline_always__ unrolledImpl () = - let - val pareduce = case Int.precision of - SOME 8 => Unrolled8.pareduce - | SOME 16 => Unrolled16.pareduce - | SOME 32 => Unrolled32.pareduce - | SOME 64 => Unrolled64.pareduce - | _ => Unrolled64.pareduce (* fallback to 64-bit for IntInf *) - in - __inline_always__ pareduce (lo, hi) zero (fn (i, a) => combine (a, f i)) combine - end - in - primSporkChoose (__inline_always__ f, __inline_always__ unrolledImpl, __inline_always__ regularImpl) - end + fun __inline_always__ unifiedReducem (combine: 'a * 'a -> 'a) (zero: 'a) (lo: int, hi: int) (f: int -> 'a) : 'a = + let + fun __inline_always__ regularImpl () = __inline_always__ Reducem.f combine zero (lo, hi) f + fun __inline_always__ unrolledImpl () = __inline_always__ UnrolledReducem.f combine zero (lo, hi) f + in + primSporkChoose (__inline_always__ f, __inline_always__ unrolledImpl, __inline_always__ regularImpl) + end fun unifiedParform (lo: int, hi: int) (f: int -> unit) : unit = - let - fun __inline_always__ 
regularImpl () = - let - val parform = case Int.precision of - SOME 8 => Loops8.parform - | SOME 16 => Loops16.parform - | SOME 32 => Loops32.parform - | SOME 64 => Loops64.parform - | _ => LoopsInt.parform - in - parform (lo, hi) f - end - - fun __inline_always__ unrolledImpl () = - let - val pareduce = case Int.precision of - SOME 8 => Unrolled8.pareduce - | SOME 16 => Unrolled16.pareduce - | SOME 32 => Unrolled32.pareduce - | SOME 64 => Unrolled64.pareduce - | _ => Unrolled64.pareduce (* fallback to 64-bit for IntInf *) - in - __inline_always__ pareduce (lo, hi) () (fn (i, _) => f i) (fn _ => ()) - end + let + fun __inline_always__ regularImpl () = __inline_always__ Parform.f (lo, hi) f - in - primSporkChoose (__inline_always__ f, __inline_always__ unrolledImpl, __inline_always__ regularImpl) - end + fun __inline_always__ unrolledImpl () = __inline_always__ UnrolledParform.f (lo, hi) f + in + primSporkChoose (__inline_always__ f, __inline_always__ unrolledImpl, __inline_always__ regularImpl) + end fun __inline_always__ unifiedPareduce (lo: int, hi: int) (zero: 'a) (step: int * 'a -> 'a) (combine: 'a * 'a -> 'a) : 'a = let fun __inline_always__ regularImpl () = - let - val pareduce = case Int.precision of - SOME 8 => Loops8.pareduce - | SOME 16 => Loops16.pareduce - | SOME 32 => Loops32.pareduce - | SOME 64 => Loops64.pareduce - | _ => LoopsInt.pareduce - in - pareduce (lo, hi) zero step combine - end + __inline_always__ Pareduce.f (lo, hi) zero step combine fun __inline_always__ unrolledImpl () = - let - val pareduce = case Int.precision of - SOME 8 => Unrolled8.pareduce - | SOME 16 => Unrolled16.pareduce - | SOME 32 => Unrolled32.pareduce - | SOME 64 => Unrolled64.pareduce - | _ => Unrolled64.pareduce (* fallback to 64-bit for IntInf *) - in - __inline_always__ pareduce (lo, hi) zero step combine - end - - fun __inline_always__ loopBody i = step (i, zero) + __inline_always__ UnrolledPareduce.f (lo, hi) zero step combine + + fun __inline_always__ loopBody i = 
__inline_always__ step (i, zero) in primSporkChoose (__inline_always__ loopBody, __inline_always__ unrolledImpl, __inline_always__ regularImpl) end in - val reducem = unifiedReducem - val parform = unifiedParform - val pareduce = unifiedPareduce - val parfor = ForkJoin0.parfor + val reducem = __inline_always__ unifiedReducem + val reduce = __inline_always__ unifiedReducem + val reducemDefault = __inline_always__ Reducem.f + val parform = __inline_always__ unifiedParform + val parformDefault = __inline_always__ Parform.f + val pareduce = __inline_always__ unifiedPareduce + val parfor = __inline_always__ ForkJoin0.parfor end - val pareduceBreakExn = PareduceBreakExn.f -end + val pareduceBreakExn = __inline_always__ PareduceBreakExn.f +end \ No newline at end of file diff --git a/basis-library/schedulers/spork/LoopIndex.sml b/basis-library/schedulers/spork/LoopIndex.sml new file mode 100644 index 000000000..1d09ea641 --- /dev/null +++ b/basis-library/schedulers/spork/LoopIndex.sml @@ -0,0 +1,58 @@ +signature LOOP_INDEX = +sig + type idx + type t = idx + + val fromInt: int -> idx + val toInt: idx -> int + + val increment: idx -> idx + val midpoint: idx * idx -> idx + val equal: idx * idx -> bool +end + + +functor LoopIndexFromWord(WordImpl: WORD) :> LOOP_INDEX = +struct + type idx = WordImpl.word + type t = idx + + fun __inline_always__ toInt (w: idx) = __inline_always__ WordImpl.toIntX w + fun __inline_always__ fromInt i = __inline_always__ WordImpl.fromInt i + + fun __inline_always__ midpoint (i: idx, j: idx) = + let + (* This way is broken! *) + (* val mid = WordImpl.~>> (WordImpl.+ (i, j), 0w1) *) + + val range_size = WordImpl.+ (j, WordImpl.~ i) + val mid = WordImpl.+ (i, WordImpl.div (range_size, WordImpl.fromInt 2)) + in + (* If using a different midpoint calculation, consider uncommenting + * the following for debugging/testing. 
+ *) + + (* if toInt i <= toInt mid andalso toInt mid <= toInt j then + () + else + ( print + ( "ERROR: schedulers/spork/ForkJoin.sml: bug! midpoint failure: " + ^ Int.toString (toInt i) + ^ " " + ^ Int.toString (toInt mid) + ^ " " + ^ Int.toString (toInt j) + ^ "\n" + ) + + ; OS.Process.exit OS.Process.failure + ); *) + + mid + end + + fun __inline_always__ increment (i: idx) = + WordImpl.+ (i, fromInt 1) + + fun __inline_always__ equal (i: idx, j: idx) = (i = j) +end diff --git a/basis-library/schedulers/spork/UnrolledLoops.sml b/basis-library/schedulers/spork/UnrolledLoops.sml index fa5944fa4..829e3af85 100644 --- a/basis-library/schedulers/spork/UnrolledLoops.sml +++ b/basis-library/schedulers/spork/UnrolledLoops.sml @@ -1,18 +1,24 @@ -functor UnrolledLoops(WordImpl: WORD) = +functor UnrolledLoops(WordImpl: WORD) :> +sig + val pareduce: (int * int) -> 'a -> (int * 'a -> 'a) -> ('a * 'a -> 'a) -> 'a + val pareduceBreakExn: (int * int) -> 'a -> (('a -> exn) * int * 'a -> 'a) -> ('a * 'a -> 'a) -> 'a + val reducem: ('a * 'a -> 'a) -> 'a -> (int * int) -> (int -> 'a) -> 'a + val parform: (int * int) -> (int -> unit) -> unit +end = struct type word = WordImpl.word fun __inline_always__ w2i w = __inline_always__ WordImpl.toIntX w fun __inline_always__ i2w i = __inline_always__ WordImpl.fromInt i - val one = __inline_always__ WordImpl.fromInt 1 - val two = __inline_always__ WordImpl.fromInt 2 - val three = __inline_always__ WordImpl.fromInt 3 - val four = __inline_always__ WordImpl.fromInt 4 - val five = __inline_always__ WordImpl.fromInt 5 - val six = __inline_always__ WordImpl.fromInt 6 - val seven = __inline_always__ WordImpl.fromInt 7 - val eight = __inline_always__ WordImpl.fromInt 8 + val one = i2w 1 + val two = i2w 2 + val three = i2w 3 + val four = i2w 4 + val five = i2w 5 + val six = i2w 6 + val seven = i2w 7 + val eight = i2w 8 fun __inline_always__ midpoint (i: word, j: word) = WordImpl.+ (i, WordImpl.>> (WordImpl.- (j, i), 0w1)) @@ -193,4 +199,56 @@ struct 
loop1 (z, i2w lo, i2w hi) end + + fun __inline_always__ pareduceBreakExn (i: int, j: int) (z: 'a) (step: ('a -> exn) * int * 'a -> 'a) (merge: 'a * 'a -> 'a): 'a = + let exception Break of 'a + fun step' (i, a) = (__inline_always__ step (Break, i, a), true) handle (Break b) => (b, false) + fun merge' ((b1, cont1), (b2, cont2)) = + if cont1 then (merge (b1, b2), cont2) else (b1, false) + + (* we can reuse the pareduce structure but need to adapt it for break semantics *) + (* for simplicity, we'll use a basic implementation that wraps pareduce *) + (* a more optimized version would inline the break logic into the unrolled loops *) + + fun continue (f : 'a -> 'a * bool) : 'a * bool -> 'a * bool = + fn (b, cont) => if cont then f b else (b, cont) + + fun iter (b: 'a) (i: word, j: word): 'a * bool = + if i = j then (b, true) else + let + fun __inline_never__ spwn b' = + if WordImpl.>= (WordImpl.+ (i, one), j) then (b', true) else + let val mid = midpoint (WordImpl.+ (i, one), j) in + Scheduler.SporkJoin.spork { + tokenPolicy = Scheduler.TokenPolicyFair, + body = fn () => iter b' (WordImpl.+ (i, one), mid), + spwn = fn () => iter z (mid, j), + seq = continue (fn b' => iter b' (mid, j)), + sync = merge', + unstolen = NONE + } + end + in + Scheduler.SporkJoin.spork { + tokenPolicy = Scheduler.TokenPolicyGive, + body = fn () => __inline_always__ step' (w2i i, b), + spwn = fn () => spwn z, + seq = continue (fn b' => iter b' (WordImpl.+ (i, one), j)), + sync = merge', + unstolen = SOME (continue spwn) + } + end + val (result, cont) = __inline_always__ iter z (i2w (Int.min (i, j)), i2w j) + in + result + end + + + fun __inline_always__ reducem g z (lo, hi) f = + pareduce (lo, hi) z (fn (i, a) => __inline_always__ g (a, __inline_always__ f i)) g + + + fun __inline_always__ parform (lo: int, hi: int) (f: int -> unit) : unit = + pareduce (lo, hi) () (fn (i, _) => f i) (fn _ => ()) + end \ No newline at end of file diff --git a/basis-library/schedulers/spork/sources.mlb 
b/basis-library/schedulers/spork/sources.mlb index d168d0cc0..6c380c54c 100644 --- a/basis-library/schedulers/spork/sources.mlb +++ b/basis-library/schedulers/spork/sources.mlb @@ -38,6 +38,7 @@ local in Heartbeat.sml Scheduler.sml + LoopIndex.sml UnrolledLoops.sml end ForkJoin.sml diff --git a/mlton/closure-convert/closure-convert.fun b/mlton/closure-convert/closure-convert.fun index 3167545b9..e5ff9ff4d 100644 --- a/mlton/closure-convert/closure-convert.fun +++ b/mlton/closure-convert/closure-convert.fun @@ -573,10 +573,12 @@ fun closureConvert let fun arg i = Vector.sub (args, i) val regular = arg 2 + val unrolled = arg 1 val unitTy = Stype.tuple (Vector.new0 ()) val unitArg = Value.fromType unitTy val result = new () val _ = Value.addHandler + (* !TODO *) (varExp regular, fn l => let val lambda = Value.Lambda.dest l @@ -1430,8 +1432,8 @@ fun closureConvert val func = varExpInfo chosenImpl val funcVal = VarInfo.value func (* unit value *) - val unitTy = Type.tuple (Vector.new0 ()) - val unitExp = Dexp.tuple {exps = Vector.new0 (), ty = unitTy} + val unitTy = Type.tuple (Vector.new0 ()) (* TODO: refactor as SType.unit *) + val unitExp = Dexp.tuple {exps = Vector.new0 (), ty = unitTy} (* ! 
Dexp.unit = val unit = Tuple {exps = Vector.new0 (), ty = Type.unit} *) val unitVal = Value.fromType (Stype.tuple (Vector.new0 ())) val {cons, ...} = valueLambdasInfo funcVal in From e4a0603138a94f4837bb37bb9129ba88cfa215b0 Mon Sep 17 00:00:00 2001 From: Sundara Vishnu Satish Date: Tue, 18 Nov 2025 13:39:43 -0500 Subject: [PATCH 13/14] feat: parameterize (compile-time) spork_choose threshold --- mlton/closure-convert/closure-convert.fun | 2 +- mlton/control/control-flags.sig | 1 + mlton/control/control-flags.sml | 5 +++++ mlton/main/main.fun | 6 ++++++ 4 files changed, 13 insertions(+), 1 deletion(-) diff --git a/mlton/closure-convert/closure-convert.fun b/mlton/closure-convert/closure-convert.fun index e5ff9ff4d..ec297eb04 100644 --- a/mlton/closure-convert/closure-convert.fun +++ b/mlton/closure-convert/closure-convert.fun @@ -1418,7 +1418,7 @@ fun closureConvert end (* choose between unrolled and regular based on loop body size *) - val threshold = 100 + val threshold = !Control.sporkChooseThreshold val chosenImpl = if loopBodySize <= threshold then unrolled else regular val _ = Control.diagnostics (fn display => diff --git a/mlton/control/control-flags.sig b/mlton/control/control-flags.sig index 8b0404160..5d23804b2 100644 --- a/mlton/control/control-flags.sig +++ b/mlton/control/control-flags.sig @@ -69,6 +69,7 @@ signature CONTROL_FLAGS = val closureConvertGlobalize: bool ref val closureConvertShrink: bool ref + val sporkChooseThreshold: int ref structure Codegen: sig diff --git a/mlton/control/control-flags.sml b/mlton/control/control-flags.sml index d4f002de4..b30403b10 100644 --- a/mlton/control/control-flags.sml +++ b/mlton/control/control-flags.sml @@ -230,6 +230,11 @@ val closureConvertShrink = control {name = "closureConvertShrink", default = true, toString = Bool.toString} +val sporkChooseThreshold: int ref = + control {name = "spork choose threshold", + default = 100, + toString = Int.toString} + structure Codegen = struct datatype t = diff --git 
a/mlton/main/main.fun b/mlton/main/main.fun index 9e43abe2d..6a5777187 100644 --- a/mlton/main/main.fun +++ b/mlton/main/main.fun @@ -290,6 +290,12 @@ fun makeOptions {usage} = (Expert, "closure-convert-shrink", " {true|false}", "whether to shrink during closure conversion", Bool (fn b => (closureConvertShrink := b))), + (Expert, "spork-choose-threshold", " ", + "threshold for spork_choose loop body size decision, default = 100", + Int (fn n => + if n < 0 + then usage "spork-choose-threshold must be non-negative" + else sporkChooseThreshold := n)), (Normal, "codegen", concat [" {", String.concatWith From e13ca2fa27fd953923facebc0a09b1e4375b9d8f Mon Sep 17 00:00:00 2001 From: Sundara Vishnu Satish Date: Tue, 18 Nov 2025 14:36:00 -0500 Subject: [PATCH 14/14] feat: loopChoose wip, untested --- basis-library/schedulers/spork/ForkJoin.sml | 85 ++++++++++++++++- basis-library/schedulers/spork/Scheduler.sml | 8 ++ .../schedulers/spork/UnrolledLoops.sml | 60 ++++++++++++ mlton/atoms/prim.fun | 23 +++++ mlton/atoms/prim.sig | 1 + mlton/closure-convert/closure-convert.fun | 94 +++++++++++++++++++ 6 files changed, 270 insertions(+), 1 deletion(-) diff --git a/basis-library/schedulers/spork/ForkJoin.sml b/basis-library/schedulers/spork/ForkJoin.sml index 8f7181047..93d28c0bb 100644 --- a/basis-library/schedulers/spork/ForkJoin.sml +++ b/basis-library/schedulers/spork/ForkJoin.sml @@ -61,6 +61,8 @@ sig val pareduceBreakExn: (int * int) -> 'a -> (('a -> exn) * int * 'a -> 'a) -> ('a * 'a -> 'a) -> 'a val reducem: ('a * 'a -> 'a) -> 'a -> (int * int) -> (int -> 'a) -> 'a val parform: (int * int) -> (int -> unit) -> unit + val seqLoop: (int * int) -> (int -> unit) -> unit + val seqReduce: ('a * 'a -> 'a) -> 'a -> (int * int) -> (int -> 'a) -> 'a end = struct type idx = LoopIndex.t @@ -147,6 +149,26 @@ struct fun __inline_always__ parform (lo: int, hi: int) (f: int -> unit) : unit = reducem (fn _ => ()) () (lo, hi) f + + + fun __inline_always__ seqLoop (lo: int, hi: int) (f: int -> 
unit) : unit = + let + fun loop (i: idx, j: idx) : unit = + if LoopIndex.equal (i, j) then () + else (__inline_always__ f (LoopIndex.toInt i); loop (LoopIndex.increment i, j)) + in + loop (LoopIndex.fromInt (Int.min (lo, hi)), LoopIndex.fromInt hi) + end + + + fun __inline_always__ seqReduce (combine: 'a * 'a -> 'a) (zero: 'a) (lo: int, hi: int) (f: int -> 'a) : 'a = + let + fun loop (acc: 'a) (i: idx, j: idx) : 'a = + if LoopIndex.equal (i, j) then acc + else loop (__inline_always__ combine (acc, __inline_always__ f (LoopIndex.toInt i))) (LoopIndex.increment i, j) + in + loop zero (LoopIndex.fromInt (Int.min (lo, hi)), LoopIndex.fromInt hi) + end end @@ -170,6 +192,9 @@ sig val parfor: int -> (int * int) -> (int -> unit) -> unit val alloc: int -> 'a array + val seqLoop: (int * int) -> (int -> unit) -> unit + val seqReduce: ('a * 'a -> 'a) -> 'a -> (int * int) -> (int -> 'a) -> 'a + val idleTimeSoFar: unit -> Time.time val workTimeSoFar: unit -> Time.time val maxForkDepthSoFar: unit -> int @@ -280,7 +305,47 @@ struct val fInt16 = Unrolled16.parform val fInt32 = Unrolled32.parform val fInt64 = Unrolled64.parform - val fIntInf = Unrolled64.parform + val fIntInf = Unrolled64.parform + end) + + structure SeqLoop = + Int_ChooseFromInt (struct + type 'a t = (int * int) -> (int -> unit) -> unit + val fInt8 = Loops8.seqLoop + val fInt16 = Loops16.seqLoop + val fInt32 = Loops32.seqLoop + val fInt64 = Loops64.seqLoop + val fIntInf = LoopsInt.seqLoop + end) + + structure SeqReduce = + Int_ChooseFromInt (struct + type 'a t = ('a * 'a -> 'a) -> 'a -> (int * int) -> (int -> 'a) -> 'a + val fInt8 = Loops8.seqReduce + val fInt16 = Loops16.seqReduce + val fInt32 = Loops32.seqReduce + val fInt64 = Loops64.seqReduce + val fIntInf = LoopsInt.seqReduce + end) + + structure UnrolledSeqLoop = + Int_ChooseFromInt (struct + type 'a t = (int * int) -> (int -> unit) -> unit + val fInt8 = Unrolled8.seqLoop + val fInt16 = Unrolled16.seqLoop + val fInt32 = Unrolled32.seqLoop + val fInt64 = 
Unrolled64.seqLoop + val fIntInf = Unrolled64.seqLoop + end) + + structure UnrolledSeqReduce = + Int_ChooseFromInt (struct + type 'a t = ('a * 'a -> 'a) -> 'a -> (int * int) -> (int -> 'a) -> 'a + val fInt8 = Unrolled8.seqReduce + val fInt16 = Unrolled16.seqReduce + val fInt32 = Unrolled32.seqReduce + val fInt64 = Unrolled64.seqReduce + val fIntInf = Unrolled64.seqReduce end) local @@ -314,6 +379,22 @@ struct in primSporkChoose (__inline_always__ loopBody, __inline_always__ unrolledImpl, __inline_always__ regularImpl) end + + fun __inline_always__ unifiedSeqLoop (lo: int, hi: int) (f: int -> unit) : unit = + let + fun __inline_always__ regularImpl () = __inline_always__ SeqLoop.f (lo, hi) f + fun __inline_always__ unrolledImpl () = __inline_always__ UnrolledSeqLoop.f (lo, hi) f + in + Scheduler.primLoopChoose (__inline_always__ f, __inline_always__ unrolledImpl, __inline_always__ regularImpl) + end + + fun __inline_always__ unifiedSeqReduce (combine: 'a * 'a -> 'a) (zero: 'a) (lo: int, hi: int) (f: int -> 'a) : 'a = + let + fun __inline_always__ regularImpl () = __inline_always__ SeqReduce.f combine zero (lo, hi) f + fun __inline_always__ unrolledImpl () = __inline_always__ UnrolledSeqReduce.f combine zero (lo, hi) f + in + Scheduler.primLoopChoose (__inline_always__ f, __inline_always__ unrolledImpl, __inline_always__ regularImpl) + end in val reducem = __inline_always__ unifiedReducem val reduce = __inline_always__ unifiedReducem @@ -322,6 +403,8 @@ struct val parformDefault = __inline_always__ Parform.f val pareduce = __inline_always__ unifiedPareduce val parfor = __inline_always__ ForkJoin0.parfor + val seqLoop = __inline_always__ unifiedSeqLoop + val seqReduce = __inline_always__ unifiedSeqReduce end val pareduceBreakExn = __inline_always__ PareduceBreakExn.f diff --git a/basis-library/schedulers/spork/Scheduler.sml b/basis-library/schedulers/spork/Scheduler.sml index 0fa517cdb..099cf030c 100644 --- a/basis-library/schedulers/spork/Scheduler.sml +++ 
b/basis-library/schedulers/spork/Scheduler.sml @@ -109,6 +109,12 @@ struct * (unit -> 'a) (* unrolled implementation *) * (unit -> 'a) (* regular implementation *) -> 'a; + val primLoopChoose' = + _prim "loop_choose" + : ('u -> 'a) (* loop body *) + * (unit -> 'a) (* unrolled implementation *) + * (unit -> 'a) (* regular implementation *) + -> 'a; fun __inline_always__ primSporkFair (body, spwn, seq, sync, exnseq, exnsync) = __inline_always__ primSporkFair' (body, (), spwn, (), seq, sync, exnseq, exnsync) @@ -118,6 +124,8 @@ struct __inline_always__ primSporkGive' (body, (), spwn, (), seq, sync, exnseq, exnsync) fun __inline_always__ primSporkChoose (loopBody, unrolled, regular) = __inline_always__ primSporkChoose' (loopBody, unrolled, regular) + fun __inline_always__ primLoopChoose (loopBody, unrolled, regular) = + __inline_always__ primLoopChoose' (loopBody, unrolled, regular) val primForkThreadAndSetData = _prim "spork_forkThreadAndSetData": Thread.t * 'a -> Thread.p; val primForkThreadAndSetData_youngest = _prim "spork_forkThreadAndSetData_youngest": Thread.t * 'a -> Thread.p; diff --git a/basis-library/schedulers/spork/UnrolledLoops.sml b/basis-library/schedulers/spork/UnrolledLoops.sml index 829e3af85..460449cd4 100644 --- a/basis-library/schedulers/spork/UnrolledLoops.sml +++ b/basis-library/schedulers/spork/UnrolledLoops.sml @@ -4,6 +4,8 @@ sig val pareduceBreakExn: (int * int) -> 'a -> (('a -> exn) * int * 'a -> 'a) -> ('a * 'a -> 'a) -> 'a val reducem: ('a * 'a -> 'a) -> 'a -> (int * int) -> (int -> 'a) -> 'a val parform: (int * int) -> (int -> unit) -> unit + val seqLoop: (int * int) -> (int -> unit) -> unit + val seqReduce: ('a * 'a -> 'a) -> 'a -> (int * int) -> (int -> 'a) -> 'a end = struct @@ -251,4 +253,62 @@ struct fun __inline_always__ parform (lo: int, hi: int) (f: int -> unit) : unit = pareduce (lo, hi) () (fn (i, _) => f i) (fn _ => ()) + + fun __inline_always__ seqLoop (lo: int, hi: int) (f: int -> unit) : unit = + let + fun loop8 (i: word, 
j: word) : unit = + if WordImpl.<= (WordImpl.+ (i, eight), j) then + let + val _ = __inline_always__ f (w2i i) + val _ = __inline_always__ f (w2i (WordImpl.+ (i, one))) + val _ = __inline_always__ f (w2i (WordImpl.+ (i, two))) + val _ = __inline_always__ f (w2i (WordImpl.+ (i, three))) + val _ = __inline_always__ f (w2i (WordImpl.+ (i, four))) + val _ = __inline_always__ f (w2i (WordImpl.+ (i, five))) + val _ = __inline_always__ f (w2i (WordImpl.+ (i, six))) + val _ = __inline_always__ f (w2i (WordImpl.+ (i, seven))) + in + loop8 (WordImpl.+ (i, eight), j) + end + else + loop1 (i, j) + + and loop1 (i: word, j: word) : unit = + if WordImpl.< (i, j) then + (__inline_always__ f (w2i i); loop1 (WordImpl.+ (i, one), j)) + else + () + in + loop8 (i2w lo, i2w hi) + end + + + fun __inline_always__ seqReduce (combine: 'a * 'a -> 'a) (zero: 'a) (lo: int, hi: int) (f: int -> 'a) : 'a = + let + fun loop8 (acc: 'a, i: word, j: word) : 'a = + if WordImpl.<= (WordImpl.+ (i, eight), j) then + let + val acc = __inline_always__ combine (acc, __inline_always__ f (w2i i)) + val acc = __inline_always__ combine (acc, __inline_always__ f (w2i (WordImpl.+ (i, one)))) + val acc = __inline_always__ combine (acc, __inline_always__ f (w2i (WordImpl.+ (i, two)))) + val acc = __inline_always__ combine (acc, __inline_always__ f (w2i (WordImpl.+ (i, three)))) + val acc = __inline_always__ combine (acc, __inline_always__ f (w2i (WordImpl.+ (i, four)))) + val acc = __inline_always__ combine (acc, __inline_always__ f (w2i (WordImpl.+ (i, five)))) + val acc = __inline_always__ combine (acc, __inline_always__ f (w2i (WordImpl.+ (i, six)))) + val acc = __inline_always__ combine (acc, __inline_always__ f (w2i (WordImpl.+ (i, seven)))) + in + loop8 (acc, WordImpl.+ (i, eight), j) + end + else + loop1 (acc, i, j) + + and loop1 (acc: 'a, i: word, j: word) : 'a = + if WordImpl.< (i, j) then + loop1 (__inline_always__ combine (acc, __inline_always__ f (w2i i)), WordImpl.+ (i, one), j) + else + acc + in + 
loop8 (zero, i2w lo, i2w hi) + end + end \ No newline at end of file diff --git a/mlton/atoms/prim.fun b/mlton/atoms/prim.fun index 35ecdcaec..446389c44 100644 --- a/mlton/atoms/prim.fun +++ b/mlton/atoms/prim.fun @@ -113,6 +113,7 @@ datatype 'a t = | MLton_share (* to rssa (as nop or runtime C fn) *) | MLton_size (* to rssa (as runtime C fn) *) | MLton_touch (* to rssa (as nop) or backend (as nop) *) + | Loop_choose (* closure convert *) (* Choose between unrolled and regular at compile time *) | Spork_choose (* closure convert *) | Spork of {tokenSplitPolicy: Word32.word} (* closure convert *) @@ -296,6 +297,7 @@ fun toString (n: 'a t): string = | MLton_share => "MLton_share" | MLton_size => "MLton_size" | MLton_touch => "MLton_touch" + | Loop_choose => "loop_choose" | Spork_choose => "spork_choose" | Spork {tokenSplitPolicy=0w0} => "spork_fair" | Spork {tokenSplitPolicy=0w1} => "spork_keep" @@ -465,6 +467,7 @@ val equals: 'a t * 'a t -> bool = | (MLton_share, MLton_share) => true | (MLton_size, MLton_size) => true | (MLton_touch, MLton_touch) => true + | (Loop_choose, Loop_choose) => true (* TODO: Check usage properly *) | (Spork_choose, Spork_choose) => true | (Spork {tokenSplitPolicy = tsp1}, Spork {tokenSplitPolicy = tsp2}) => tsp1 = tsp2 @@ -654,6 +657,7 @@ val map: 'a t * ('a -> 'b) -> 'b t = | MLton_touch => MLton_touch | Spork tsp => Spork tsp (* TODO: Check usage properly *) + | Loop_choose => Loop_choose | Spork_choose => Spork_choose | Spork_forkThreadAndSetData z => Spork_forkThreadAndSetData z | Spork_getData spid => Spork_getData spid @@ -872,6 +876,7 @@ val kind: 'a t -> Kind.t = | MLton_touch => SideEffect | Spork _ => SideEffect (* TODO: Check usage properly *) + | Loop_choose => SideEffect | Spork_choose => SideEffect | Spork_forkThreadAndSetData _ => SideEffect | Spork_getData _ => DependsOnState @@ -1087,6 +1092,7 @@ in Spork {tokenSplitPolicy = 0w1}, Spork {tokenSplitPolicy = 0w2}, (* TODO: Check usage properly *) + Loop_choose, Spork_choose, 
Spork_forkThreadAndSetData {youngest=true}, Spork_forkThreadAndSetData {youngest=false}, @@ -1455,6 +1461,15 @@ fun 'a checkApp (prim: 'a t, in (eightArgs (cont, taa, spwn, tba, seq, sync, exnseq, exnsync), tc) end) + | Loop_choose => + (* TODO: Check usage properly *) + twoTargs (fn (ta, tu) => + let + val loopBody = arrow (tu, ta) (* First arg: loop body function 'u -> 'a *) + val impl = arrow (unit, ta) (* Second and third args: thunks unit -> 'a *) + in + (threeArgs (loopBody, impl, impl), ta) + end) | Spork_choose => (* TODO: Check usage properly *) (* spork_choose: ('u -> 'v) -> (unit -> 'a) -> (unit -> 'a) -> 'a @@ -1621,6 +1636,14 @@ fun ('a, 'b) extractTargs (prim: 'b t, in six (taa, tar, tba, tbr, td, tc) end + | Loop_choose => + (* TODO: Check usage properly *) + let + val ta = result (* Result type 'a *) + val (tu, _) = deArrow (arg 0) (* First arg: loop body ('u -> 'v) *) + in + Vector.new2 (ta, tu) + end | Spork_choose => (* TODO: Check usage properly *) (* spork_choose: ('u -> 'v) -> 'a -> 'a -> 'a *) diff --git a/mlton/atoms/prim.sig b/mlton/atoms/prim.sig index 9188455d4..5c4b47117 100644 --- a/mlton/atoms/prim.sig +++ b/mlton/atoms/prim.sig @@ -104,6 +104,7 @@ signature PRIM = | MLton_share (* to rssa (as nop or runtime C fn) *) | MLton_size (* to rssa (as runtime C fn) *) | MLton_touch (* to rssa (as nop) or backend (as nop) *) + | Loop_choose (* closure convert *) | Spork_choose (* TODO: closure convert / SSA / SSA2 / RSSA ? 
*) | Spork of {tokenSplitPolicy: Word32.word} (* closure convert *) | Spork_forkThreadAndSetData of {youngest: bool} (* to rssa (as runtime C fn) *) diff --git a/mlton/closure-convert/closure-convert.fun b/mlton/closure-convert/closure-convert.fun index ec297eb04..d743759c7 100644 --- a/mlton/closure-convert/closure-convert.fun +++ b/mlton/closure-convert/closure-convert.fun @@ -1464,6 +1464,100 @@ fun closureConvert end))}, ac) end + | SprimExp.PrimApp {prim = Prim.Loop_choose, targs, args} => + (* loop_choose: ('u -> 'a) -> (unit -> 'a) -> (unit -> 'a) -> 'a *) + let + fun arg i = Vector.sub (args, i) + val loopBodyVar = arg 0 + val unrolled = arg 1 + val regular = arg 2 + + val loopBodyInfo = varExpInfo loopBodyVar + val loopBodySize = + let + val v = #value loopBodyInfo + fun extractLambda (v: Value.t): Slambda.t option = + case Value.dest v of + Value.Lambdas ls => + (case Lambdas.toList ls of + l :: _ => SOME (Value.Lambda.dest l) + | [] => NONE) + | _ => NONE + in + case extractLambda v of + SOME l => + let + val size = sxmlLambdaSize l + val _ = Control.diagnostics + (fn display => + display (let open Layout + in seq [str "loop_choose loop body size: ", + Int.layout size] + end)) + val _ = analyzeLambdaBody l + in + size + end + | NONE => + let + val _ = Control.diagnostics + (fn display => + display (let open Layout + in seq [str "loop_choose: no lambda found for ", + SvarExp.layout loopBodyVar] + end)) + in + 0 + end + end + + (* choose between unrolled and regular based on loop body size *) + val threshold = !Control.sporkChooseThreshold (* same knob as spork_choose; set by the spork-choose-threshold option *) + val chosenImpl = if loopBodySize <= threshold then unrolled else regular + val _ = Control.diagnostics + (fn display => + display (let open Layout + in seq [str " Decision: using ", + str (if loopBodySize <= threshold then "UNROLLED" else "REGULAR"), + str " (threshold=", Int.layout threshold, str ")"] + end)) + + (* apply the chosen implementation to unit *) + val func = varExpInfo chosenImpl + val funcVal = VarInfo.value func + (* unit value *)
+ val unitTy = Type.tuple (Vector.new0 ()) + val unitExp = Dexp.tuple {exps = Vector.new0 (), ty = unitTy} + val unitVal = Value.fromType (Stype.tuple (Vector.new0 ())) + val {cons, ...} = valueLambdasInfo funcVal + in + (* ! heavy code duplication from apply *) + (Dexp.casee + {test = convertVarInfo func, + ty = ty, + default = NONE, + cases = + Dexp.Con + (Vector.map + (cons, fn {lambda, con} => + let + val {arg = param, body, ...} = Slambda.dest lambda + val info as LambdaInfo.T {name, ...} = lambdaInfo lambda + val result = expValue body + val env = (Var.newString "env", lambdaInfoType info) + in {con = con, + args = Vector.new1 env, + body = coerce (Dexp.call + {func = name, + args = Vector.new2 (Dexp.var env, + coerce (unitExp, unitVal, + value param)), + inline = InlineAttr.Auto, + ty = valueType result}, + result, v)} + end))}, + ac) + end | SprimExp.PrimApp {prim, targs, args} => let val prim = Prim.map (prim, convertType)