diff --git a/configs/common/CacheConfig.py b/configs/common/CacheConfig.py
index 3adf07fe8c..0c98e50449 100644
--- a/configs/common/CacheConfig.py
+++ b/configs/common/CacheConfig.py
@@ -72,6 +72,41 @@ def _get_cache_opts(cpu, level, options):
 
     return opts
 
+def apply_matrix_timing_options(cpu, options):
+    """Copy the ``--matrix-*`` timing options that are set onto a CPU's ISAs.
+
+    Every name in ``timing_attrs`` is looked up on ``options``; names that
+    are missing or ``None`` are skipped. Each remaining value is assigned to
+    every object in ``cpu.isa`` that already has an attribute of that name.
+    A ``cpu`` without an ``isa`` attribute is left untouched.
+    """
+    timing_attrs = [
+        "matrix_issue_interval_cycles",
+        "matrix_load_base_cycles",
+        "matrix_store_base_cycles",
+        "matrix_zero_cycles",
+        "matrix_compute_base_cycles",
+        "matrix_compute_read_cycles",
+        "matrix_release_cycles",
+        "matrix_local_mmu_issue_per_cycle",
+        "matrix_local_mmu_arb_cycles",
+        "matrix_l2_request_pipeline_cycles",
+        "matrix_l2_response_pipeline_cycles",
+        "matrix_local_mmu_read_latency_cycles",
+        "matrix_local_mmu_write_ack_latency_cycles",
+    ]
+    # The requested values do not depend on the ISA object, so collect them
+    # once instead of re-reading ``options`` for every ISA.
+    requested = {}
+    for attr in timing_attrs:
+        value = getattr(options, attr, None)
+        if value is not None:
+            requested[attr] = value
+    for isa in getattr(cpu, "isa", []):
+        for attr, value in requested.items():
+            if hasattr(isa, attr):
+                setattr(isa, attr, value)
+
 def config_classic_l2(options, system, l2_cache_class):
     # When using classic L2 cache, The prefetcher is inside the l2cache, instead of l2Wrapper
     # So we need to move the prefetcher from l2Wrapper to l2cache
@@ -298,6 +333,8 @@ def config_cache(options, system):
         system.memchecker = MemChecker()
 
     for i in range(options.num_cpus):
+        apply_matrix_timing_options(system.cpu[i], options)
+
         if options.caches:
             icache = icache_class(**_get_cache_opts(system.cpu[i], 'l1i', options))
             dcache = dcache_class(**_get_cache_opts(system.cpu[i], 'l1d', options))
diff --git a/configs/common/Options.py b/configs/common/Options.py
index 441840b5c8..964567b854 100644
--- a/configs/common/Options.py
+++ b/configs/common/Options.py
@@ -334,6 +334,43 @@ def addCommonOptions(parser, configure_xiangshan=False):
 
     parser.add_argument("--ideal-kmhv3", action= "store_true",
                         help="Use KunminghuV3 ideal params, which take priority over command-line arguments.")
+    # Coarse CUTE matrix timing knobs.
These are intentionally behavior-level + # controls for sensitivity studies, not cycle-accurate RTL stage controls. + parser.add_argument("--matrix-issue-interval-cycles", type=int, + default=None, + help="CUTE matrix task issue interval in CPU cycles") + parser.add_argument("--matrix-load-base-cycles", type=int, default=None, + help="CUTE matrix load base latency in CPU cycles") + parser.add_argument("--matrix-store-base-cycles", type=int, default=None, + help="CUTE matrix store base latency in CPU cycles") + parser.add_argument("--matrix-zero-cycles", type=int, default=None, + help="CUTE matrix zero latency in CPU cycles") + parser.add_argument("--matrix-compute-base-cycles", type=int, + default=None, + help="Fixed CUTE matrix compute ready latency in CPU cycles") + parser.add_argument("--matrix-compute-read-cycles", type=int, + default=None, + help="Fixed CUTE matrix compute source read latency in CPU cycles") + parser.add_argument("--matrix-release-cycles", type=int, default=None, + help="CUTE matrix release latency in CPU cycles") + parser.add_argument("--matrix-local-mmu-issue-per-cycle", type=int, + default=None, + help="CUTE LocalMMU request issue throughput per CPU cycle") + parser.add_argument("--matrix-local-mmu-arb-cycles", type=int, + default=None, + help="CUTE LocalMMU arbitration latency in CPU cycles") + parser.add_argument("--matrix-l2-request-pipeline-cycles", type=int, + default=None, + help="CUTE-to-L2 request pipeline latency in CPU cycles") + parser.add_argument("--matrix-l2-response-pipeline-cycles", type=int, + default=None, + help="CUTE L2 response service interval in CPU cycles") + parser.add_argument("--matrix-local-mmu-read-latency-cycles", type=int, + default=None, + help="CUTE LocalMMU read response latency in CPU cycles") + parser.add_argument("--matrix-local-mmu-write-ack-latency-cycles", + type=int, default=None, + help="CUTE LocalMMU write acknowledgement latency in CPU cycles") # for warmup without switching cpu 
parser.add_argument("--warmup-insts-no-switch", action="store", type=int, diff --git a/configs/example/ai_idealkmhv3.py b/configs/example/ai_idealkmhv3.py new file mode 100644 index 0000000000..34b1068456 --- /dev/null +++ b/configs/example/ai_idealkmhv3.py @@ -0,0 +1,189 @@ +import argparse +import os +import sys + +import m5 +from m5.defines import buildEnv +from m5.objects import * +from m5.util import addToPath, fatal, warn +from m5.util.fdthelper import * + +addToPath('../') + +from ruby import Ruby +from common.LSQBankConflict import set_lsq_bank_conflict_cache_params + +from common.FSConfig import * +from common.SysPaths import * +from common.Benchmarks import * +from common import Simulation +from common.Caches import * +from common.xiangshan import * + +from m5.objects.ValuePredictor import * + + +AI_MATRIX_TIMING_DEFAULTS = { + "matrix_issue_interval_cycles": 1, + "matrix_load_base_cycles": 4, + "matrix_store_base_cycles": 4, + "matrix_zero_cycles": 1, + "matrix_compute_base_cycles": 2, + "matrix_compute_read_cycles": 1, + "matrix_release_cycles": 1, + "matrix_local_mmu_issue_per_cycle": 1, + "matrix_local_mmu_arb_cycles": 1, + "matrix_l2_request_pipeline_cycles": 1, + "matrix_l2_response_pipeline_cycles": 1, + "matrix_local_mmu_read_latency_cycles": 20, + "matrix_local_mmu_write_ack_latency_cycles": 12, +} + + +def setAiMatrixTimingDefaults(args): + # Keep CUTE timing coarse by default. Command-line --matrix-* options remain + # higher priority so calibration runs can sweep these knobs directly. 
+ for attr, value in AI_MATRIX_TIMING_DEFAULTS.items(): + if getattr(args, attr, None) is None: + setattr(args, attr, value) + + +def setAiKmhV3IdealParams(args, system): + for cpu in system.cpu: + + # fetch + cpu.mmu.itb.size = 96 + cpu.fetchWidth = 32 + cpu.iewToFetchDelay = 2 # for resolved update, should train branch after squash + cpu.commitToFetchDelay = 2 + cpu.fetchQueueSize = 64 + + # decode + cpu.fetchToDecodeDelay = 5 + cpu.decodeWidth = 8 + cpu.enable_loadFusion = False + cpu.enableConstantFolding = False + + # rename + cpu.renameWidth = 8 + cpu.numPhysIntRegs = 224 + cpu.numPhysFloatRegs = 256 + + # dispatch + cpu.enableDispatchStage = False + cpu.numDQEntries = [8, 8, 8] + cpu.dispWidth = [8, 8, 8] + + # scheduler + cpu.scheduler = KMHV3Scheduler() + + # rob + cpu.commitWidth = 12 + cpu.squashWidth = 12 + cpu.phyregReleaseWidth = 8 + cpu.RobCompressPolicy = 'kmhv3' + cpu.numROBEntries = 160 + cpu.CROB_instPerGroup = 2 # 1 if not using ROB compression + + # lsu + cpu.StoreWbStage = 4 + cpu.EnableLdMissReplay = True + cpu.EnablePipeNukeCheck = True + cpu.BankConflictCheck = True + cpu.sbufferBankWriteAccurately = True + cpu.DcacheSetDivNum = 2 + + # value predictor + cpu.valuePred = IdealConstantLVP() + + # lsq + cpu.LQEntries = 128 + cpu.SQEntries = 64 + cpu.RARQEntries = 96 + cpu.RAWQEntries = 56 + cpu.LoadCompletionWidth = 8 + cpu.StoreCompletionWidth = 4 + cpu.RARDequeuePerCycle = 4 + cpu.RAWDequeuePerCycle = 4 + cpu.SbufferEntries = 24 + cpu.SbufferEvictThreshold = 16 + cpu.store_prefetch_train = False + + # branch predictor + if args.bp_type == 'DecoupledBPUWithBTB': + cpu.branchPred.ftq_size = 64 + cpu.branchPred.fsq_size = 64 + # TAGE table sizes and numWays tuning + cpu.branchPred.tage.tableSizes = [2048, 2048, 8192, 8192, 8192, 8192, 8192, 2048] + cpu.branchPred.tage.numWays = [2, 2, 4, 2, 2, 2, 2, 2] + # cpu.branchPred.microtage.enabled = False + + # l1 cache per core + if args.caches: + cpu.icache.size = '64kB' + cpu.dcache.size = '64kB' + 
cpu.dcache.tag_load_read_ports = 100 + cpu.dcache.mshrs = 16 + cpu.dcache.simulate_dcache_refill = True + set_lsq_bank_conflict_cache_params(cpu, system) + + # l2 caches + if args.l2cache: + for i in range(args.num_cpus): + if args.classic_l2: + system.l2_caches[i].slice_num = 0 # 4 -> 0, no slice + else: + l2_wrapper = system.l2_wrappers[i] + l2_wrapper.data_sram_banks = 2 + l2_wrapper.dir_sram_banks = 2 + l2_wrapper.pipe_dir_write_stage = 4 + l2_wrapper.dir_read_bypass = True + for j in range(args.l2_slices): + # Configure XSDRRIP replacement policy (DRRIP mode) + # Each slice: 2MB/4 = 512KB, 8-way, 64B line -> 1024 sets + l2_wrapper.slices[j].inner_cache.replacement_policy = XSDRRIPRP(mode=2, num_sets=1024) + system.tol2bus_list[i].forward_latency = 0 # 3->0 + system.tol2bus_list[i].response_latency = 0 # 3->0 + system.tol2bus_list[i].hint_wakeup_ahead_cycles = 0 # 2->0 + + # ReqLayer[0]: ICache+DCache+ITB+DTB -> L2, allow 2 requests per cycle + # RespLayer[1]: L2 -> DCache, allow 2 responses per cycle + system.tol2bus_list[i].layer_bandwidth_configs = [ + LayerBandwidthConfig(direction="req", port_index=0, max_per_cycle=2), + LayerBandwidthConfig(direction="resp", port_index=1, max_per_cycle=2), + ] + + # l3 cache + if args.l3cache: + system.l3.mshrs = 128 + + +if __name__ == '__m5_main__': + FutureClass = None + + args = xiangshan_system_init() + + assert not args.external_memory_system + + # AI performance runs use the same ideal KMHV3 CPU/cache envelope as + # idealkmhv3.py, plus explicit coarse CUTE timing defaults. 
+    args.bp_type = 'DecoupledBPUWithBTB'
+    args.l2_size = '2MB'
+    args.l3_size = '32MB'
+    args.enable_pf_buffer = False
+    args.enable_riscv_vector = True
+    setAiMatrixTimingDefaults(args)
+
+    # Match the memories with the CPUs, based on the options for the test system
+    TestMemClass = Simulation.setMemClass(args)
+
+    test_sys = build_xiangshan_system(args)
+    if args.raw_cpt and args.generic_rv_cpt and os.path.basename(args.generic_rv_cpt) == "linux.bin":
+        configure_xiangshan_linux_workload(test_sys, args)
+
+    # Set ideal parameters here with the highest priority, over command-line arguments
+    setAiKmhV3IdealParams(args, test_sys)
+
+    root = Root(full_system=True, system=test_sys)
+
+    Simulation.run_vanilla(args, root, test_sys, FutureClass)
diff --git a/configs/example/se.py b/configs/example/se.py
index b7403fd464..1cda480e8d 100644
--- a/configs/example/se.py
+++ b/configs/example/se.py
@@ -124,7 +124,42 @@ def get_processes(args):
 if '--ruby' in sys.argv:
     Ruby.define_options(parser)
 
-def setDefaultArgs(args):
+def explicitOptionDests(parser, argv):
+    """Return the dests of the options that appear explicitly in ``argv``.
+
+    ``argv[0]`` is skipped and scanning stops at a bare ``--``. Options may
+    be spelled ``--opt value`` or ``--opt=value``. Unambiguous prefixes of
+    long options are resolved too, because argparse accepts them by default;
+    otherwise an abbreviated option would be reported as not given and
+    setDefaultArgs() would overwrite the value the user supplied.
+    """
+    option_to_dest = {}
+    for action in parser._actions:
+        for option in action.option_strings:
+            option_to_dest[option] = action.dest
+    allow_abbrev = getattr(parser, 'allow_abbrev', True)
+
+    explicit = set()
+    for arg in argv[1:]:
+        if arg == '--':
+            break
+        option = arg.split('=', 1)[0]
+        dest = option_to_dest.get(option)
+        if dest is None and allow_abbrev and option.startswith('--'):
+            # Accept a long-option prefix only when every option string it
+            # matches maps to the same dest, so the result is unambiguous.
+            matches = {
+                candidate_dest
+                for candidate, candidate_dest in option_to_dest.items()
+                if candidate.startswith(option)
+            }
+            if len(matches) == 1:
+                dest = matches.pop()
+        if dest is not None:
+            explicit.add(dest)
+    return explicit
+
+def setDefaultArgs(args, explicit_options):
     """Set default configurations to match xiangshan.py SE mode defaults"""
 
     # Set defaults only if not already specified by user
@@ -154,7 +189,10 @@ def setDefaultArgs(args):
     }
     # default warmup 100k instructions!
for key, value in defaults.items(): - # if not hasattr(args, key) or getattr(args, key) is None: + if key in explicit_options: + continue + if key == 'l3cache' and 'no_l3cache' in explicit_options: + continue setattr(args, key, value) # Set dramsim3_ini path @@ -165,7 +184,7 @@ def setDefaultArgs(args): args = parser.parse_args() # Set default configurations -setDefaultArgs(args) +setDefaultArgs(args, explicitOptionDests(parser, sys.argv)) multiprocesses = [] numThreads = 1 diff --git a/docs/Gem5_Docs/xsai/se_matrix_smoke.md b/docs/Gem5_Docs/xsai/se_matrix_smoke.md index 907dffe15f..04f48968fa 100644 --- a/docs/Gem5_Docs/xsai/se_matrix_smoke.md +++ b/docs/Gem5_Docs/xsai/se_matrix_smoke.md @@ -22,45 +22,55 @@ 包含: - 当前 smoke 路径需要的最小 matrix 指令 decode -- `RiscvISA::ISA` 中的最小 matrix 架构状态 +- `MatrixController` 中的最小 matrix 架构状态,`RiscvISA::ISA` 保留适配层 - matrix tile 配置 - matrix load/store helper - `int8 x int8 -> int32 accumulate` 的最小功能模型 - `msyncreset` / `mrelease` / `macquire` 的最小 token 语义 - SE 模式下为了 syscall/trap 正确性而关闭 rename folding 的配置修复 +- CUTE-aligned `MatrixController` 行为级控制骨架 +- 基于 fixed/analytic ready tick 的 matrix issue/scoreboard/token timing,可通过 + `mrelease/macquire` 观察 acquire stall +- 行为级 CUTE-to-L2 request pipeline:LocalMMU 仲裁、source id 占用、 + TL-A 请求带宽、读数据/写 ack 响应端口 +- matrix 指令映射到 O3 op class,能参与 O3 issue/execute 资源模型 +- `system.cpu.isa.*` 下的 matrix task、memory、timing、acquire stall 统计 +- RiscvISA 参数化的 matrix analytic timing 常量 不包含: - 完整 AME 指令覆盖 - FS/raw Linux bring-up -- 周期准确的 matrix 执行时序模型 -- matrix 单元和 cache hierarchy 的真实时序交互 +- 周期准确的 matrix 执行时序模型或 RTL cycle-accurate CUTE 模型 +- matrix 单元和 cache hierarchy 的真实 packet 级时序交互 +- 真实 cache miss、DRAM backpressure、TL retry 的逐级 RTL 时序展开 - 比当前 smoke 程序更广的 matrix 指令族 ## 当前实现思路 -### 1. matrix 状态放在 `RiscvISA::ISA` +### 1. 
matrix 状态放在 `MatrixController` -当前 SE smoke 走的是一个简单的功能模型,状态直接保存在 -`RiscvISA::ISA` 里: +当前 SE smoke 仍保留简单功能模型,但状态已经从 `RiscvISA::ISA` +迁移到 `src/matrix/MatrixController` 中,ISA 只保留 decoder 调用适配层: - tile 尺寸 - - `matrixTileM` - - `matrixTileK` - - `matrixTileN` + - `tileM` + - `tileK` + - `tileN` - matrix 数据缓冲 - - `matrixTileA` - - `matrixTileB` - - `matrixAcc` + - AB matrix register storage + - C accumulator register storage - token 状态 - - `matrixTokens` + - 32 个 matrix token -这对当前 smoke 用例已经足够,因为当前只关心功能正确性,不关心精细时序。 +这对当前 smoke 用例已经足够,因为当前首先保证功能正确性,同时把控制面 +拆出来作为后续性能模型的挂点。 可以把它理解成: - 当前先把 matrix 单元建成一个 gem5 内部的功能模型 - 先保证用户态程序能“算对、跑通” -- 暂时不建真正的硬件时序 +- 控制面开始维护 CUTE-aligned FIFO、scoreboard、LocalMMU 统计和行为级 + ready tick ### 2. matrix 内存访问走 translating proxy @@ -75,26 +85,120 @@ - 当前 matrix load/store 主要是按“功能正确”的方式访问内存 - 目标是保证用户态程序看到正确结果 -- 不是去模拟一个真实的 AME memory pipeline +- 功能数据仍在指令执行时同步产生或写回 +- matrix 数据计算不做 cycle 级数据通路建模,也不按 `M/N/K` 展开计算流水; + 只保留固定/粗粒度 ready tick + 以维持依赖和 token 语义 因此当前实现里: - 计算本身是功能模型 - 和 cache 的交互也主要是功能正确优先 -- 没有给 AME 指令补单独的执行延迟 -- 没有做 matrix 单元与 LSU/L2 的时序建模 - -### 3. token 语义是最小功能版 - -当前 token 模型是故意简化的: +- AME 计算指令只有固定抽象 timing,用于控制 token/acquire 的可观测 stall +- matrix load/store/mmacc 同时映射到 O3 `MemReadOp` / `MemWriteOp` / + `IntMultOp`,配置和同步类指令映射到 `IntAluOp` +- 默认 CUTE 访存不向共享 L2 发逐请求 timing probe packet;只保留 + controller 内部的高性能抽象 request/source/response timing + +### 3. 
CUTE-to-L2 请求管理是行为级管线 + +matrix load/store 的性能侧不再只用“请求数乘固定延迟”的总量公式,而是在 +`MatrixController` 中按请求推进一个抽象 LocalMMU/L2 管线: + +- LocalMMU 仲裁延迟:`matrix_local_mmu_arb_cycles` +- CUTE-to-L2 请求管线延迟:`matrix_l2_request_pipeline_cycles` +- LocalMMU source id 数量限制,source 在读响应或写 ack 返回前保持占用 +- source id 分配按 RTL `Cute2TL` 的空闲 source 选择方式抽象:有空闲 source + 时取最高编号,source 全满时等待最早释放的 source;选择点是抽象 TL-A + issue slot 的 fire tick,而不是更早的 request-ready tick +- TL-A 请求带宽:`matrix_local_mmu_issue_per_cycle` +- C store request 数按 MatrixMN rounding、stride 和 64B 外部传输粒度切分, + 避免低估非对齐 store 的抽象访存压力 +- 读响应和写 ack 共用一个抽象响应端口,服务间隔由 + `matrix_l2_response_pipeline_cycles` 控制 +- response 完成后在 controller 内更新 completion/store barrier/token timing; + 不再建模端口侧 response queue 的 1-2 拍可见边界 + +这和 RTL 的 LocalMMU/Cute2TL 控制语义只做到必要的行为级近似,不追逐 +1-2 拍的 requester-port 细节。默认生产路径不再把 CUTE 接成一个真实 +`matrix_l2_port` timing requester,也不会向 shared L2 发送逐 64B probe +packet。controller 内部仍保留 request count、LocalMMU issue throughput、 +source outstanding、读/写响应延迟、统一 response slot 和 token/acquire +可见 timing,用这些粗粒度约束表达对性能有明显影响的长延迟访存行为。 + +需要注意的是,RTL 里 CUTE 没有自己的 L2 cache,架构上仍是共享 CPU 的 +L2 fabric;但 gem5 默认模型现在只在 controller 内部抽象这条路径,不连接 +独立的 CUTE requester port。旧的 `MatrixL2RequestPort`、逐请求 timing +callback、端口侧 response/helper 和 `memoryL2Port*` 统计已经删除;后续若要 +重做 shared-L2 探针,应该作为独立实验路径重新接入,不能混入默认模型。 + +### 4. token 语义接入行为级 timing + +当前 token 模型仍是同步 smoke 可用的简化版,但 `mrelease` 不再总是立即 +增加 token,而是生成一个 ready tick: - `msyncreset(tok)`:把 token 清零 -- `mrelease(tok)`:token 加一 -- `macquire(tok, target)`:要求 `token >= target` - -这不是完整的异步完成模型,只是当前 smoke 所需的最小契约。 - -### 4. SE 模式关闭 rename folding +- `mrelease(tok)`:等待前序已排 matrix task 的 analytic completion tick + 后产生 token event +- `macquire(tok, target)`:若 token 未达到 target,则按 controller 内部 + 预测 ready tick quiesce;默认不再等待 CUTE L2 port in-flight 状态 + +这不是完整的异步完成模型,但已经能把 CUTE matrix-release-acquire +控制路径反映到 gem5 仿真时间中。当前 `mrelease` 在行为级模型中按 +conservative barrier 处理,不表示 RTL 精确 release 条件。 + +### 5. 
可调参数与统计 + +`RiscvISA` 暴露了当前 analytic timing 常量: + +- `matrix_issue_interval_cycles` +- `matrix_load_base_cycles` +- `matrix_store_base_cycles` +- `matrix_zero_cycles` +- `matrix_compute_base_cycles`,固定 compute ready latency,不随 tile shape 展开 +- `matrix_compute_read_cycles`,固定 compute source read latency +- `matrix_release_cycles` +- `matrix_local_mmu_issue_per_cycle` +- `matrix_local_mmu_arb_cycles` +- `matrix_l2_request_pipeline_cycles` +- `matrix_l2_response_pipeline_cycles` +- `matrix_local_mmu_read_latency_cycles` +- `matrix_local_mmu_write_ack_latency_cycles` + +其中对行为级性能趋势影响最大的旋钮是 LocalMMU request issue throughput、 +CUTE-to-L2 request/response pipeline latency、读响应/写 ack latency、load/store +base latency 和 release barrier latency。它们在通用配置脚本中也有对应的命令 +行入口: + +- `--matrix-local-mmu-issue-per-cycle` +- `--matrix-l2-request-pipeline-cycles` +- `--matrix-l2-response-pipeline-cycles` +- `--matrix-local-mmu-read-latency-cycles` +- `--matrix-local-mmu-write-ack-latency-cycles` +- `--matrix-load-base-cycles` +- `--matrix-store-base-cycles` +- `--matrix-zero-cycles` +- `--matrix-release-cycles` +- `--matrix-compute-base-cycles` +- `--matrix-compute-read-cycles` + +这些选项用于粗粒度敏感性分析;不要把它们理解成 RTL 每一级 FIFO 或每一级 +流水 stage 的精确控制。 + +运行后可在 `stats.txt` 的 `system.cpu.isa.*` 下观察: + +- task 计数:`tasksAccepted`、`mmaTasks`、`cStoreTasks` +- LocalMMU 行为级统计:`memoryRequests`、`memoryBusBytes` +- CUTE-to-L2 管线统计:`memoryPipelineRequests`、 + `memoryPipelineReadResponses`、`memoryPipelineWriteAcks`、 + `memoryPipelineSourceStallTicks`、`memoryPipelineRequestQueueTicks`、 + `memoryPipelineResponseQueueTicks`、`memoryPipelineMaxOutstanding` +- 默认模型不会驱动 shared-L2 requester 端口,也不再注册 `memoryL2Port*` + 旧探针计数器 +- timing 统计:`timingTasks`、`timingQueueTicks`、`timingBusyTicks` +- 同步统计:`acquireStallEvents`、`acquireStallTicks`、`tokenReleaseEvents` + +### 6. 
SE 模式关闭 rename folding `configs/example/se.py` 里会关闭: @@ -160,6 +264,30 @@ SE 下 trap/syscall 对寄存器的写回绕过了正常的 renamed writeback - allocator 测试全部通过 - 能进入 randomized matrix fuzz - 在扩大的 timeout 窗口内没有完整跑完,但在 timeout 前没有观察到 correctness failure +- MatrixController 行为/timing 单测 + - `scons --unit-test build/RISCV/matrix/matrix_controller.test.opt -j16` + - `build/RISCV/matrix/matrix_controller.test.opt` + - 11/11 通过 + - 覆盖 `mmacc.w.b` 数据结果、issue/scoreboard timing、`mrelease` + barrier、issue interval clamp、LocalMMU source id 耗尽、读响应/写 ack + 分流、source 在抽象 TL-A issue tick 选择、控制面 LocalMMU allocator + 按最高空闲 source 分配、非对齐 C store 的 MatrixMN rounding/stride + request 切分和 pending token reset。 +- 完整 RISCV 目标构建 + - `scons build/RISCV/gem5.opt -j16` + - 通过 + - 仅有 libpng、HDF5、backtrace 相关可选依赖警告 +- `coremark-2-iteration.bin` 1000 指令短仿真 + - 正常以 max instruction count 退出 + - `config.ini` 可见 matrix analytic timing 与 L2 request pipeline 参数 + - `stats.txt` 可见 `system.cpu.isa.*` 下 matrix timing/task/L2 pipeline 统计 + - 该 workload 不含 matrix 指令,因此只验证参数和统计 plumbing +- 2026-05-25 默认抽象路径短仿真 + - `/tmp/gem5-cute-abstract-cut2` + - 正常以 max instruction count 退出 + - `config.ini` 不再出现 `system.cpu.isa.matrix_l2_port` + - `stats.txt` 不再出现 `memoryL2Port*` 旧探针统计 +- 历史验证里曾跑过 `coremark` 短仿真和 full-gem5 plumbing 检查;这些结果只能证明旧的探针路径曾经可接通,不能代表当前默认模型。 ## 一条典型运行命令 @@ -175,9 +303,25 @@ cd GEM5 ## 已知限制 - 当前还是 smoke-oriented 的功能模型,不是完整 AME 模型 -- token 语义还是最小功能版,没有和真正异步完成路径绑定 -- 当前主要是“功能支持”,还没有做 matrix 指令的时序模拟 -- 当前计算和 cache 交互主要是功能正确、无延迟的形式 +- token 初始 ready tick 来自行为级 analytic timing;默认模型没有 + CUTE L2 port 的 in-flight request/response/source 状态 +- `mrelease` 当前在行为级模型中作为前序 matrix task 的 conservative + barrier,不是 RTL 精确 release 条件 +- 当前 matrix timing 只能用于粗粒度行为级 timing 探索;compute 侧是固定 + ready tick,不用于性能评估,访存侧是 CUTE-to-shared-L2 request 管理的 + 行为级性能模型 +- 默认 CUTE requester 端口不接入 `L1ToL2Bus`/`L2CacheWrapper`,也不处理 + xbar retry、真实 cache response 或端口侧 response queue +- 旧的端口 reset/drain、response 回写逻辑和端口侧测试已经从代码中删除 +- shared-L2 
backpressure/retry/response 不作为当前默认行为级反馈输入; + 真实 cache/DRAM latency、MSHR 占用、TL retry 没有逐级展开成 RTL 级事件 +- 当前计算和内存功能结果主要是功能正确优先;计算 timing 不随数据通路 + cycle 展开,访存 timing 通过 + scoreboard/token/acquire 推进仿真时间,不改变同步功能数据路径 +- CStore 只保留抽象 request 数和 response timing,不发送 shared-L2 + `WriteReq` packet +- `matrixsimpletest-riscv64-xs.bin` 当前可进入 real simulation,但仍会早期触发 + 既有 O3 commit-stuck,不能作为 matrix timing 正确性回归依据 - `hello_xsai` 的 randomized fuzz 明显比核心 smoke 用例更耗时 - FS/Linux 支持应该作为后续独立工作推进 @@ -188,4 +332,6 @@ cd GEM5 1. 把 `libc_mmap_smoke_xsai`、`precomp_rand_repro`、`gemm_precomp` 固化为回归基线 2. 把后续主要 bring-up 重心转到 FS -3. 只有在真实 workload 需要时,再继续扩展 matrix 指令覆盖和时序模型 +3. 用可稳定运行的 matrix workload 校准 controller 内部抽象 LocalMMU/L2 + timing 参数,并和 RTL 关键吞吐/队列计数对齐 +4. 只有在真实 workload 需要时,再继续扩展 matrix 指令覆盖和时序模型 diff --git a/src/arch/riscv/RiscvISA.py b/src/arch/riscv/RiscvISA.py index a54dcfd2a1..b9961b8c91 100644 --- a/src/arch/riscv/RiscvISA.py +++ b/src/arch/riscv/RiscvISA.py @@ -39,8 +39,36 @@ # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
from m5.objects.BaseISA import BaseISA +from m5.params import * class RiscvISA(BaseISA): type = 'RiscvISA' cxx_class = 'gem5::RiscvISA::ISA' cxx_header = "arch/riscv/isa.hh" + + matrix_issue_interval_cycles = Param.Unsigned( + 1, "Analytic CUTE matrix issue interval in CPU cycles") + matrix_load_base_cycles = Param.Unsigned( + 4, "Analytic CUTE matrix load base latency in CPU cycles") + matrix_store_base_cycles = Param.Unsigned( + 4, "Analytic CUTE matrix store base latency in CPU cycles") + matrix_zero_cycles = Param.Unsigned( + 1, "Analytic CUTE matrix zero latency in CPU cycles") + matrix_compute_base_cycles = Param.Unsigned( + 2, "Fixed abstract CUTE matrix compute ready latency in CPU cycles") + matrix_compute_read_cycles = Param.Unsigned( + 1, "Fixed abstract CUTE matrix compute source read latency in CPU cycles") + matrix_release_cycles = Param.Unsigned( + 1, "Analytic CUTE matrix release latency in CPU cycles") + matrix_local_mmu_issue_per_cycle = Param.Unsigned( + 1, "Analytic CUTE LocalMMU request issue throughput per CPU cycle") + matrix_local_mmu_arb_cycles = Param.Unsigned( + 1, "Analytic CUTE LocalMMU arbitration latency in CPU cycles") + matrix_l2_request_pipeline_cycles = Param.Unsigned( + 1, "Analytic CUTE-to-L2 request pipeline latency in CPU cycles") + matrix_l2_response_pipeline_cycles = Param.Unsigned( + 1, "Analytic CUTE L2 response port service interval in CPU cycles") + matrix_local_mmu_read_latency_cycles = Param.Unsigned( + 20, "Analytic CUTE LocalMMU read response latency in CPU cycles") + matrix_local_mmu_write_ack_latency_cycles = Param.Unsigned( + 12, "Analytic CUTE LocalMMU write acknowledgement latency in CPU cycles") diff --git a/src/arch/riscv/insts/SConscript b/src/arch/riscv/insts/SConscript index e57832f98d..4aeee46a9b 100644 --- a/src/arch/riscv/insts/SConscript +++ b/src/arch/riscv/insts/SConscript @@ -30,7 +30,6 @@ Import('*') Source('amo.cc', tags='riscv isa') Source('compressed.cc', tags='riscv isa') Source('mem.cc', 
tags='riscv isa') -Source('matrix.cc', tags='riscv isa') Source('standard.cc', tags='riscv isa') Source('static_inst.cc', tags='riscv isa') Source('vector.cc', tags='riscv isa') diff --git a/src/arch/riscv/insts/matrix.cc b/src/arch/riscv/insts/matrix.cc deleted file mode 100644 index 0bd84999c9..0000000000 --- a/src/arch/riscv/insts/matrix.cc +++ /dev/null @@ -1,32 +0,0 @@ -/* - * Minimal AME helpers for XS-GEM5 bring-up. - */ - -#include "arch/riscv/insts/matrix.hh" - -namespace gem5 -{ - -namespace RiscvISA -{ - -uint32_t -clampMatrixTileM(uint64_t value) -{ - return value > MatrixMaxM ? MatrixMaxM : static_cast(value); -} - -uint32_t -clampMatrixTileK(uint64_t value) -{ - return value > MatrixMaxK ? MatrixMaxK : static_cast(value); -} - -uint32_t -clampMatrixTileN(uint64_t value) -{ - return value > MatrixMaxN ? MatrixMaxN : static_cast(value); -} - -} // namespace RiscvISA -} // namespace gem5 diff --git a/src/arch/riscv/insts/matrix.hh b/src/arch/riscv/insts/matrix.hh deleted file mode 100644 index 23cd3b6cfc..0000000000 --- a/src/arch/riscv/insts/matrix.hh +++ /dev/null @@ -1,31 +0,0 @@ -/* - * Minimal AME helpers for XS-GEM5 bring-up. 
- */ - -#ifndef __ARCH_RISCV_INSTS_MATRIX_HH__ -#define __ARCH_RISCV_INSTS_MATRIX_HH__ - -#include - -namespace gem5 -{ - -namespace RiscvISA -{ - -static constexpr uint32_t MatrixMaxM = 128; -static constexpr uint32_t MatrixMaxK = 64; -static constexpr uint32_t MatrixMaxN = 128; - -static constexpr uint32_t MatrixTileABytes = MatrixMaxM * MatrixMaxK; -static constexpr uint32_t MatrixTileBBytes = MatrixMaxN * MatrixMaxK; -static constexpr uint32_t MatrixAccElems = MatrixMaxM * MatrixMaxN; - -uint32_t clampMatrixTileM(uint64_t value); -uint32_t clampMatrixTileK(uint64_t value); -uint32_t clampMatrixTileN(uint64_t value); - -} // namespace RiscvISA -} // namespace gem5 - -#endif // __ARCH_RISCV_INSTS_MATRIX_HH__ diff --git a/src/arch/riscv/isa.cc b/src/arch/riscv/isa.cc index 730af71a98..343676641c 100644 --- a/src/arch/riscv/isa.cc +++ b/src/arch/riscv/isa.cc @@ -30,13 +30,14 @@ #include "arch/riscv/isa.hh" +#include #include +#include #include #include #include "arch/riscv/interrupts.hh" #include "arch/riscv/mmu.hh" -#include "arch/riscv/pagetable.hh" #include "arch/riscv/pmp.hh" #include "arch/riscv/regs/float.hh" #include "arch/riscv/regs/int.hh" @@ -49,6 +50,7 @@ #include "base/trace.hh" #include "base/types.hh" #include "cpu/base.hh" +#include "cpu/thread_context.hh" #include "debug/Checkpoint.hh" #include "debug/FloatRegs.hh" #include "debug/IntRegs.hh" @@ -56,10 +58,8 @@ #include "debug/MiscRegs.hh" #include "debug/RiscvMisc.hh" #include "debug/VecRegs.hh" -#include "mem/packet.hh" +#include "matrix/matrix_controller.hh" #include "mem/request.hh" -#include "mem/se_translating_port_proxy.hh" -#include "mem/translating_port_proxy.hh" #include "params/RiscvISA.hh" #include "sim/faults.hh" #include "sim/full_system.hh" @@ -74,43 +74,25 @@ namespace RiscvISA namespace { -Fault -matrixReadBlob(ThreadContext *tc, Addr addr, void *dst, size_t size) -{ - bool ok = false; - if (FullSystem) { - TranslatingPortProxy proxy(tc); - ok = proxy.tryReadBlob(addr, dst, size); - 
} else { - SETranslatingPortProxy proxy(tc); - ok = proxy.tryReadBlob(addr, dst, size); - } - - if (!ok) { - return std::make_shared(addr); - } - return NoFault; -} +constexpr RegVal MatrixCapabilityBits = 1; +constexpr int MatrixMxrmOffset = 0; +constexpr int MatrixMsatOffset = 2; +constexpr int MatrixMfflagsOffset = 3; +constexpr int MatrixMfrmOffset = 8; +constexpr int MatrixMsatenOffset = 11; -Fault -matrixWriteBlob(ThreadContext *tc, Addr addr, const void *src, size_t size) +RegVal +composeMatrixCSR(RegVal mxrm, RegVal msat, RegVal mfflags, RegVal mfrm, + RegVal msaten) { - bool ok = false; - if (FullSystem) { - TranslatingPortProxy proxy(tc); - ok = proxy.tryWriteBlob(addr, src, size); - } else { - SETranslatingPortProxy proxy(tc); - ok = proxy.tryWriteBlob(addr, src, size); - } - - if (!ok) { - return std::make_shared(addr); - } - return NoFault; + return ((mxrm & MCSR_MXRM_MASK) << MatrixMxrmOffset) | + ((msat & MCSR_MSAT_MASK) << MatrixMsatOffset) | + ((mfflags & MCSR_MFFLAGS_MASK) << MatrixMfflagsOffset) | + ((mfrm & MCSR_MFRM_MASK) << MatrixMfrmOffset) | + ((msaten & MCSR_MSATEN_MASK) << MatrixMsatenOffset); } -} // namespace +} // anonymous namespace [[maybe_unused]] const std::array MiscRegNames = {{ [MISCREG_PRV] = "PRV", @@ -310,12 +292,56 @@ matrixWriteBlob(ThreadContext *tc, Addr addr, const void *src, size_t size) [MISCREG_NMIVEC] = "NMIVEC", [MISCREG_NMIE] = "NMIE", [MISCREG_NMIP] = "NMIP", + + [MISCREG_MCSR] = "MCSR", + [MISCREG_MXRM] = "MXRM", + [MISCREG_MSAT] = "MSAT", + [MISCREG_MFFLAGS] = "MFFLAGS", + [MISCREG_MFRM] = "MFRM", + [MISCREG_MSATEN] = "MSATEN", + [MISCREG_MTYPE] = "MTYPE", + [MISCREG_MTILEM] = "MTILEM", + [MISCREG_MTILEN] = "MTILEN", + [MISCREG_MTILEK] = "MTILEK", + [MISCREG_MLENB] = "MLENB", + [MISCREG_MRLENB] = "MRLENB", + [MISCREG_MAMUL] = "MAMUL", + [MISCREG_XMISA] = "XMISA", + [MISCREG_XTLENB] = "XTLENB", + [MISCREG_XTRLENB] = "XTRLENB", + [MISCREG_XALENB] = "XALENB", + [MISCREG_MTOK] = "MTOK", + [MISCREG_XMTILEM] = 
"XMTILEM", + [MISCREG_XMTILEN] = "XMTILEN", + [MISCREG_XMTILEK] = "XMTILEK", }}; -ISA::ISA(const Params &p) : BaseISA(p) +ISA::ISA(const Params &p) : BaseISA(p), + matrixController(std::make_unique(this)) { + matrix::MatrixController::TimingConfig matrix_timing; + matrix_timing.issueIntervalCycles = p.matrix_issue_interval_cycles; + matrix_timing.loadBaseCycles = p.matrix_load_base_cycles; + matrix_timing.storeBaseCycles = p.matrix_store_base_cycles; + matrix_timing.zeroCycles = p.matrix_zero_cycles; + matrix_timing.computeBaseCycles = p.matrix_compute_base_cycles; + matrix_timing.computeReadCycles = p.matrix_compute_read_cycles; + matrix_timing.releaseCycles = p.matrix_release_cycles; + matrix_timing.localMmuIssuePerCycle = + p.matrix_local_mmu_issue_per_cycle; + matrix_timing.localMmuArbCycles = p.matrix_local_mmu_arb_cycles; + matrix_timing.l2RequestPipelineCycles = + p.matrix_l2_request_pipeline_cycles; + matrix_timing.l2ResponsePipelineCycles = + p.matrix_l2_response_pipeline_cycles; + matrix_timing.localMmuReadLatencyCycles = + p.matrix_local_mmu_read_latency_cycles; + matrix_timing.localMmuWriteAckLatencyCycles = + p.matrix_local_mmu_write_ack_latency_cycles; + matrixController->setTimingConfig(matrix_timing); + _regClasses.emplace_back(IntRegClass, int_reg::NumRegs, debug::IntRegs, sizeof(RegVal)); _regClasses.emplace_back(FloatRegClass, float_reg::NumRegs, debug::FloatRegs, sizeof(RegVal)); @@ -333,6 +359,8 @@ ISA::ISA(const Params &p) : BaseISA(p) clear(); } +ISA::~ISA() = default; + bool ISA::inUserMode() const { return miscRegFile[MISCREG_PRV] == PRV_U; @@ -389,136 +417,185 @@ void ISA::clear() miscRegFile[MISCREG_VSSTATUS] = miscRegFile[MISCREG_STATUS] & NEMU_SSTATUS_RMASK; miscRegFile[MISCREG_ARCHID] = 0x19; miscRegFile[MISCREG_VENDORID] = (16ULL << 7) | 0x6FULL; + + miscRegFile[MISCREG_MCSR] = 0; + miscRegFile[MISCREG_MXRM] = 0; + miscRegFile[MISCREG_MSAT] = 0; + miscRegFile[MISCREG_MFFLAGS] = 0; + miscRegFile[MISCREG_MFRM] = 0; + 
miscRegFile[MISCREG_MSATEN] = 0; + miscRegFile[MISCREG_MTYPE] = 0; + miscRegFile[MISCREG_MTILEM] = 0; + miscRegFile[MISCREG_MTILEN] = 0; + miscRegFile[MISCREG_MTILEK] = 0; + miscRegFile[MISCREG_MLENB] = matrix::MatrixController::MatrixABRegBytes; + miscRegFile[MISCREG_MRLENB] = matrix::MatrixController::ReduceWidthBytes; + miscRegFile[MISCREG_MAMUL] = + matrix::MatrixController::MatrixAccElems * + matrix::MatrixController::ResultWidthBytes; + miscRegFile[MISCREG_XMISA] = MatrixCapabilityBits; + miscRegFile[MISCREG_XTLENB] = matrix::MatrixController::MatrixABRegBytes; + miscRegFile[MISCREG_XTRLENB] = + matrix::MatrixController::ReduceWidthBytes; + miscRegFile[MISCREG_XALENB] = + matrix::MatrixController::MatrixAccElems * + matrix::MatrixController::ResultWidthBytes; + miscRegFile[MISCREG_MTOK] = matrix::MatrixController::TokenCount; + miscRegFile[MISCREG_XMTILEM] = 0; + miscRegFile[MISCREG_XMTILEN] = 0; + miscRegFile[MISCREG_XMTILEK] = 0; } void ISA::resetMatrixState() { - matrixTileM = 0; - matrixTileK = 0; - matrixTileN = 0; - matrixTileA.assign(MatrixTileABytes, 0); - matrixTileB.assign(MatrixTileBBytes, 0); - matrixAcc.assign(MatrixAccElems, 0); - matrixTokens.assign(32, 0); + matrixController->reset(); } void ISA::matrixSyncReset(uint64_t token_idx) { - matrixToken(token_idx) = 0; + matrixController->syncReset(token_idx); } void -ISA::matrixRelease(uint64_t token_idx) +ISA::matrixRelease(ExecContext *xc, uint64_t token_idx) { - ++matrixToken(token_idx); + matrixController->release(xc, token_idx); } void -ISA::matrixAcquire(uint64_t token_idx, uint64_t target) +ISA::matrixAcquire(ExecContext *xc, uint64_t token_idx, uint64_t target) { - panic_if(matrixToken(token_idx) < target, - "macquire tok%u target=%llu observed=%llu", - token_idx, target, matrixToken(token_idx)); + if (matrixController->tokenTargetReached(token_idx, target)) { + return; + } + + matrixController->acquire(xc, token_idx, target); } void ISA::setMatrixTileM(uint64_t value) { - matrixTileM = 
clampMatrixTileM(value); + matrixController->setTileM(value); + const auto tile_m = matrixController->getTileM(); + miscRegFile[MISCREG_MTILEM] = tile_m; + miscRegFile[MISCREG_XMTILEM] = tile_m; } void ISA::setMatrixTileK(uint64_t value) { - matrixTileK = clampMatrixTileK(value); + matrixController->setTileK(value); + const auto tile_k = matrixController->getTileK(); + miscRegFile[MISCREG_MTILEK] = tile_k; + miscRegFile[MISCREG_XMTILEK] = tile_k; } void ISA::setMatrixTileN(uint64_t value) { - matrixTileN = clampMatrixTileN(value); + matrixController->setTileN(value); + const auto tile_n = matrixController->getTileN(); + miscRegFile[MISCREG_MTILEN] = tile_n; + miscRegFile[MISCREG_XMTILEN] = tile_n; +} + +uint32_t +ISA::getMatrixTileM() const +{ + return matrixController->getTileM(); +} + +uint32_t +ISA::getMatrixTileK() const +{ + return matrixController->getTileK(); +} + +uint32_t +ISA::getMatrixTileN() const +{ + return matrixController->getTileN(); } Fault ISA::matrixLoadA8(ExecContext *xc, Addr base, Addr stride) { - ThreadContext *tc = xc->tcBase(); - for (uint32_t row = 0; row < matrixTileM; ++row) { - auto *dst = reinterpret_cast(&matrixTileA[row * MatrixMaxK]); - Fault fault = matrixReadBlob(tc, base + row * stride, dst, matrixTileK); - if (fault != NoFault) { - return fault; - } - } - return NoFault; + return matrixLoadA8(xc, base, stride, + matrix::MatrixController::DefaultAReg); +} + +Fault +ISA::matrixLoadA8(ExecContext *xc, Addr base, Addr stride, uint64_t reg_idx) +{ + return matrixController->loadA8(xc, base, stride, reg_idx); } Fault ISA::matrixLoadB8(ExecContext *xc, Addr base, Addr stride) { - ThreadContext *tc = xc->tcBase(); - for (uint32_t row = 0; row < matrixTileN; ++row) { - auto *dst = reinterpret_cast(&matrixTileB[row * MatrixMaxK]); - Fault fault = matrixReadBlob(tc, base + row * stride, dst, matrixTileK); - if (fault != NoFault) { - return fault; - } - } - return NoFault; + return matrixLoadB8(xc, base, stride, + 
matrix::MatrixController::DefaultBReg); +} + +Fault +ISA::matrixLoadB8(ExecContext *xc, Addr base, Addr stride, uint64_t reg_idx) +{ + return matrixController->loadB8(xc, base, stride, reg_idx); } Fault ISA::matrixLoadC32(ExecContext *xc, Addr base, Addr stride) { - ThreadContext *tc = xc->tcBase(); - for (uint32_t row = 0; row < matrixTileM; ++row) { - auto *dst = reinterpret_cast(&matrixAcc[row * MatrixMaxN]); - Fault fault = matrixReadBlob( - tc, base + row * stride, dst, matrixTileN * sizeof(int32_t)); - if (fault != NoFault) { - return fault; - } - } - return NoFault; + return matrixLoadC32(xc, base, stride, + matrix::MatrixController::DefaultAccReg); +} + +Fault +ISA::matrixLoadC32(ExecContext *xc, Addr base, Addr stride, uint64_t acc_idx) +{ + return matrixController->loadC32(xc, base, stride, acc_idx); } Fault ISA::matrixStoreC32(ExecContext *xc, Addr base, Addr stride) { - ThreadContext *tc = xc->tcBase(); - for (uint32_t row = 0; row < matrixTileM; ++row) { - auto *src = reinterpret_cast(&matrixAcc[row * MatrixMaxN]); - Fault fault = matrixWriteBlob( - tc, base + row * stride, src, matrixTileN * sizeof(int32_t)); - if (fault != NoFault) { - return fault; - } - } - return NoFault; + return matrixStoreC32(xc, base, stride, + matrix::MatrixController::DefaultAccReg); +} + +Fault +ISA::matrixStoreC32(ExecContext *xc, Addr base, Addr stride, uint64_t acc_idx) +{ + return matrixController->storeC32(xc, base, stride, acc_idx); } void -ISA::matrixZeroAcc() +ISA::matrixZeroAcc(ExecContext *xc) { - std::fill(matrixAcc.begin(), matrixAcc.end(), 0); + matrixZeroAcc(xc, matrix::MatrixController::DefaultAccReg); } void -ISA::matrixMMAccWB() +ISA::matrixZeroAcc(ExecContext *xc, uint64_t acc_idx) { - for (uint32_t m = 0; m < matrixTileM; ++m) { - for (uint32_t n = 0; n < matrixTileN; ++n) { - int32_t acc = matrixAcc[m * MatrixMaxN + n]; - for (uint32_t k = 0; k < matrixTileK; ++k) { - int8_t a = matrixTileA[m * MatrixMaxK + k]; - int8_t b = matrixTileB[n * MatrixMaxK + 
k]; - acc += static_cast(a) * static_cast(b); - } - matrixAcc[m * MatrixMaxN + n] = acc; - } - } + matrixController->zeroAcc(xc, acc_idx); +} + +void +ISA::matrixMMAccWB(ExecContext *xc) +{ + matrixMMAccWB(xc, matrix::MatrixController::DefaultAReg, + matrix::MatrixController::DefaultBReg, + matrix::MatrixController::DefaultAccReg); +} + +void +ISA::matrixMMAccWB(ExecContext *xc, uint64_t src_a_idx, uint64_t src_b_idx, + uint64_t dst_acc_idx) +{ + matrixController->mmaccWB(xc, src_a_idx, src_b_idx, dst_acc_idx); } bool @@ -666,6 +743,63 @@ ISA::readMiscReg(int misc_reg) return VLENB; } break; + case MISCREG_MCSR: + { + return composeMatrixCSR( + miscRegFile[MISCREG_MXRM], + miscRegFile[MISCREG_MSAT], + miscRegFile[MISCREG_MFFLAGS], + miscRegFile[MISCREG_MFRM], + miscRegFile[MISCREG_MSATEN]); + } + break; + case MISCREG_MTILEM: + case MISCREG_XMTILEM: + { + return matrixController->getTileM(); + } + break; + case MISCREG_MTILEN: + case MISCREG_XMTILEN: + { + return matrixController->getTileN(); + } + break; + case MISCREG_MTILEK: + case MISCREG_XMTILEK: + { + return matrixController->getTileK(); + } + break; + case MISCREG_MLENB: + case MISCREG_XTLENB: + { + return matrix::MatrixController::MatrixABRegBytes; + } + break; + case MISCREG_MRLENB: + case MISCREG_XTRLENB: + { + return matrix::MatrixController::ReduceWidthBytes; + } + break; + case MISCREG_MAMUL: + case MISCREG_XALENB: + { + return matrix::MatrixController::MatrixAccElems * + matrix::MatrixController::ResultWidthBytes; + } + break; + case MISCREG_XMISA: + { + return MatrixCapabilityBits; + } + break; + case MISCREG_MTOK: + { + return matrix::MatrixController::TokenCount; + } + break; case MISCREG_VCSR: { return readMiscRegNoEffect(MISCREG_VXSAT) | @@ -965,6 +1099,62 @@ ISA::setMiscReg(int misc_reg, RegVal val) setMiscRegNoEffect(MISCREG_VXRM, (val & 0x6) >> 1); } break; + case MISCREG_MCSR: + { + setMiscRegNoEffect(MISCREG_MXRM, + (val >> MatrixMxrmOffset) & MCSR_MXRM_MASK); + 
setMiscRegNoEffect(MISCREG_MSAT, + (val >> MatrixMsatOffset) & MCSR_MSAT_MASK); + setMiscRegNoEffect(MISCREG_MFFLAGS, + (val >> MatrixMfflagsOffset) & MCSR_MFFLAGS_MASK); + setMiscRegNoEffect(MISCREG_MFRM, + (val >> MatrixMfrmOffset) & MCSR_MFRM_MASK); + setMiscRegNoEffect(MISCREG_MSATEN, + (val >> MatrixMsatenOffset) & MCSR_MSATEN_MASK); + setMiscRegNoEffect(MISCREG_MCSR, + composeMatrixCSR(miscRegFile[MISCREG_MXRM], + miscRegFile[MISCREG_MSAT], + miscRegFile[MISCREG_MFFLAGS], + miscRegFile[MISCREG_MFRM], + miscRegFile[MISCREG_MSATEN])); + } + break; + case MISCREG_MXRM: + case MISCREG_MSAT: + case MISCREG_MFFLAGS: + case MISCREG_MFRM: + case MISCREG_MSATEN: + { + RegVal masked = val; + switch (misc_reg) { + case MISCREG_MXRM: + masked &= MCSR_MXRM_MASK; + break; + case MISCREG_MSAT: + masked &= MCSR_MSAT_MASK; + break; + case MISCREG_MFFLAGS: + masked &= MCSR_MFFLAGS_MASK; + break; + case MISCREG_MFRM: + masked &= MCSR_MFRM_MASK; + break; + case MISCREG_MSATEN: + masked &= MCSR_MSATEN_MASK; + break; + default: + panic("Unexpected matrix CSR misc reg %d\n", misc_reg); + } + + setMiscRegNoEffect(misc_reg, masked); + setMiscRegNoEffect(MISCREG_MCSR, + composeMatrixCSR(miscRegFile[MISCREG_MXRM], + miscRegFile[MISCREG_MSAT], + miscRegFile[MISCREG_MFFLAGS], + miscRegFile[MISCREG_MFRM], + miscRegFile[MISCREG_MSATEN])); + } + break; case MISCREG_VTYPE: { DPRINTF(RiscvMisc, "Will set vs\n"); @@ -999,13 +1189,7 @@ ISA::serialize(CheckpointOut &cp) const { DPRINTF(Checkpoint, "Serializing Riscv Misc Registers\n"); SERIALIZE_CONTAINER(miscRegFile); - SERIALIZE_SCALAR(matrixTileM); - SERIALIZE_SCALAR(matrixTileK); - SERIALIZE_SCALAR(matrixTileN); - SERIALIZE_CONTAINER(matrixTileA); - SERIALIZE_CONTAINER(matrixTileB); - SERIALIZE_CONTAINER(matrixAcc); - SERIALIZE_CONTAINER(matrixTokens); + matrixController->serialize(cp); } void @@ -1013,29 +1197,7 @@ ISA::unserialize(CheckpointIn &cp) { DPRINTF(Checkpoint, "Unserializing Riscv Misc Registers\n"); 
UNSERIALIZE_CONTAINER(miscRegFile); - UNSERIALIZE_SCALAR(matrixTileM); - UNSERIALIZE_SCALAR(matrixTileK); - UNSERIALIZE_SCALAR(matrixTileN); - UNSERIALIZE_CONTAINER(matrixTileA); - UNSERIALIZE_CONTAINER(matrixTileB); - UNSERIALIZE_CONTAINER(matrixAcc); - UNSERIALIZE_CONTAINER(matrixTokens); -} - -RegVal & -ISA::matrixToken(size_t idx) -{ - panic_if(idx >= matrixTokens.size(), "matrix token index %u out of range", - idx); - return matrixTokens[idx]; -} - -const RegVal & -ISA::matrixToken(size_t idx) const -{ - panic_if(idx >= matrixTokens.size(), "matrix token index %u out of range", - idx); - return matrixTokens[idx]; + matrixController->unserialize(cp); } const int WARN_FAILURE = 10000; diff --git a/src/arch/riscv/isa.hh b/src/arch/riscv/isa.hh index df8ce507d3..8a5fe17209 100644 --- a/src/arch/riscv/isa.hh +++ b/src/arch/riscv/isa.hh @@ -35,20 +35,23 @@ #define __ARCH_RISCV_ISA_HH__ #include +#include +#include #include #include "arch/generic/isa.hh" -#include "arch/riscv/insts/matrix.hh" #include "arch/riscv/pcstate.hh" #include "arch/riscv/types.hh" #include "base/types.hh" #include "cpu/exec_context.hh" +#include "matrix/matrix_controller.hh" namespace gem5 { struct RiscvISAParams; class Checkpoint; +class ThreadContext; namespace RiscvISA { @@ -81,19 +84,7 @@ class ISA : public BaseISA { protected: std::vector miscRegFile; - uint32_t matrixTileM = 0; - uint32_t matrixTileK = 0; - uint32_t matrixTileN = 0; - std::vector matrixTileA; - std::vector matrixTileB; - std::vector matrixAcc; - std::vector matrixTokens; - - RegVal & - matrixToken(size_t idx); - - const RegVal & - matrixToken(size_t idx) const; + std::unique_ptr matrixController; public: using Params = RiscvISAParams; @@ -129,23 +120,36 @@ class ISA : public BaseISA void unserialize(CheckpointIn &cp) override; ISA(const Params &p); + ~ISA() override; void resetMatrixState(); void matrixSyncReset(uint64_t token_idx); - void matrixRelease(uint64_t token_idx); - void matrixAcquire(uint64_t token_idx, 
uint64_t target); + void matrixRelease(ExecContext *xc, uint64_t token_idx); + void matrixAcquire(ExecContext *xc, uint64_t token_idx, uint64_t target); void setMatrixTileM(uint64_t value); void setMatrixTileK(uint64_t value); void setMatrixTileN(uint64_t value); - uint32_t getMatrixTileM() const { return matrixTileM; } - uint32_t getMatrixTileK() const { return matrixTileK; } - uint32_t getMatrixTileN() const { return matrixTileN; } + uint32_t getMatrixTileM() const; + uint32_t getMatrixTileK() const; + uint32_t getMatrixTileN() const; Fault matrixLoadA8(ExecContext *xc, Addr base, Addr stride); + Fault matrixLoadA8( + ExecContext *xc, Addr base, Addr stride, uint64_t reg_idx); Fault matrixLoadB8(ExecContext *xc, Addr base, Addr stride); + Fault matrixLoadB8( + ExecContext *xc, Addr base, Addr stride, uint64_t reg_idx); Fault matrixLoadC32(ExecContext *xc, Addr base, Addr stride); + Fault matrixLoadC32( + ExecContext *xc, Addr base, Addr stride, uint64_t acc_idx); Fault matrixStoreC32(ExecContext *xc, Addr base, Addr stride); - void matrixZeroAcc(); - void matrixMMAccWB(); + Fault matrixStoreC32( + ExecContext *xc, Addr base, Addr stride, uint64_t acc_idx); + void matrixZeroAcc(ExecContext *xc); + void matrixZeroAcc(ExecContext *xc, uint64_t acc_idx); + void matrixMMAccWB(ExecContext *xc); + void matrixMMAccWB( + ExecContext *xc, uint64_t src_a_idx, uint64_t src_b_idx, + uint64_t dst_acc_idx); void handleLockedRead(const RequestPtr &req) override; diff --git a/src/arch/riscv/isa/bitfields.isa b/src/arch/riscv/isa/bitfields.isa index c50ff9fd21..40bc8b5091 100644 --- a/src/arch/riscv/isa/bitfields.isa +++ b/src/arch/riscv/isa/bitfields.isa @@ -156,5 +156,14 @@ def bitfield BIT30 <30>; def bitfield SIMM5 <19:15>; def bitfield SIMM3 <17:15>; +// Matrix legacy OP_M encoding. 
+def bitfield MATRIX_FUNCT6 <31:26>; +def bitfield MATRIX_LS <25>; +def bitfield MATRIX_IMM10 <24:15>; +def bitfield MATRIX_MA <11>; +def bitfield MATRIX_MD4 <10:7>; +def bitfield MATRIX_MS2_4 <23:20>; +def bitfield MATRIX_MS1_4 <18:15>; + //h instructions def bitfield HFUNCT2 <31:20>; diff --git a/src/arch/riscv/isa/decoder.isa b/src/arch/riscv/isa/decoder.isa index 8e670db70a..bed8e8466f 100644 --- a/src/arch/riscv/isa/decoder.isa +++ b/src/arch/riscv/isa/decoder.isa @@ -2438,15 +2438,15 @@ decode QUADRANT { 0x0: MatrixConfOp::mrelease({{ auto *isa = dynamic_cast(xc->tcBase()->getIsaPtr()); assert(isa != nullptr); - isa->matrixRelease(RS2); - }}, IsNonSpeculative, IsSerializeAfter, No_OpClass); + isa->matrixRelease(xc, RS2); + }}, IsNonSpeculative, IsSerializeAfter, IntAluOp); } 0x50: decode RD { 0x0: MatrixConfOp::macquire({{ auto *isa = dynamic_cast(xc->tcBase()->getIsaPtr()); assert(isa != nullptr); - isa->matrixAcquire(RS2, Rs1_ud); - }}, IsNonSpeculative, IsSerializeAfter, No_OpClass); + isa->matrixAcquire(xc, RS2, Rs1_ud); + }}, IsNonSpeculative, IsSerializeAfter, IsQuiesce, IntAluOp); } 0x11: decode RS2 { 0x0: decode FUNCT3 { @@ -2455,7 +2455,7 @@ decode QUADRANT { auto *isa = dynamic_cast(xc->tcBase()->getIsaPtr()); assert(isa != nullptr); isa->setMatrixTileM(Rs1_ud); - }}, No_OpClass); + }}, IntAluOp); } } } @@ -2466,7 +2466,7 @@ decode QUADRANT { auto *isa = dynamic_cast(xc->tcBase()->getIsaPtr()); assert(isa != nullptr); isa->setMatrixTileK(Rs1_ud); - }}, No_OpClass); + }}, IntAluOp); } } } @@ -2477,7 +2477,7 @@ decode QUADRANT { auto *isa = dynamic_cast(xc->tcBase()->getIsaPtr()); assert(isa != nullptr); isa->setMatrixTileN(Rs1_ud); - }}, No_OpClass); + }}, IntAluOp); } } } @@ -2488,8 +2488,8 @@ decode QUADRANT { 0x4: MatrixConfOp::mzero1r({{ auto *isa = dynamic_cast(xc->tcBase()->getIsaPtr()); assert(isa != nullptr); - isa->matrixZeroAcc(); - }}, No_OpClass); + isa->matrixZeroAcc(xc); + }}, IntAluOp); } } } @@ -2500,7 +2500,7 @@ decode QUADRANT { 
auto *isa = dynamic_cast(xc->tcBase()->getIsaPtr()); assert(isa != nullptr); return isa->matrixLoadA8(xc, Rs1_ud, Rs2_ud); - }}, No_OpClass); + }}, MemReadOp); } } 0x0a: decode FUNCT3 { @@ -2509,7 +2509,7 @@ decode QUADRANT { auto *isa = dynamic_cast(xc->tcBase()->getIsaPtr()); assert(isa != nullptr); return isa->matrixLoadB8(xc, Rs1_ud, Rs2_ud); - }}, No_OpClass); + }}, MemReadOp); } } 0x12: decode FUNCT3 { @@ -2518,7 +2518,7 @@ decode QUADRANT { auto *isa = dynamic_cast(xc->tcBase()->getIsaPtr()); assert(isa != nullptr); return isa->matrixLoadC32(xc, Rs1_ud, Rs2_ud); - }}, No_OpClass); + }}, MemReadOp); } } 0x13: decode FUNCT3 { @@ -2527,7 +2527,7 @@ decode QUADRANT { auto *isa = dynamic_cast(xc->tcBase()->getIsaPtr()); assert(isa != nullptr); return isa->matrixStoreC32(xc, Rs1_ud, Rs2_ud); - }}, No_OpClass); + }}, MemWriteOp); } } 0x0c: decode FUNCT3 { @@ -2537,14 +2537,129 @@ decode QUADRANT { 0x14: MatrixArithOp::mmacc_w_b({{ auto *isa = dynamic_cast(xc->tcBase()->getIsaPtr()); assert(isa != nullptr); - isa->matrixMMAccWB(); - }}, No_OpClass); + isa->matrixMMAccWB(xc); + }}, IntMultOp); } } } } } + 0x1d: decode MATRIX_FUNCT6 { + 0x00: decode FUNCT3 { + 0x2: decode MATRIX_LS { + 0x0: MatrixMemOp::legacy_mlce32({{ + auto *isa = dynamic_cast( + xc->tcBase()->getIsaPtr()); + assert(isa != nullptr); + return isa->matrixLoadC32( + xc, Rs1_ud, Rs2_ud, MATRIX_MD4); + }}, MemReadOp); + 0x1: MatrixMemOp::legacy_msce32({{ + auto *isa = dynamic_cast( + xc->tcBase()->getIsaPtr()); + assert(isa != nullptr); + return isa->matrixStoreC32( + xc, Rs1_ud, Rs2_ud, MATRIX_MD4); + }}, MemWriteOp); + } + 0x4: decode MATRIX_LS { + 0x1: MatrixConfOp::legacy_msettilem_i({{ + auto *isa = dynamic_cast( + xc->tcBase()->getIsaPtr()); + assert(isa != nullptr); + isa->setMatrixTileM(MATRIX_IMM10); + }}, IntAluOp); + } + 0x5: decode MATRIX_LS { + 0x1: MatrixConfOp::legacy_mcfg0_nop({{ + }}, IntAluOp); + } + } + 0x01: decode FUNCT3 { + 0x0: decode MATRIX_LS { + 0x0: 
MatrixMemOp::legacy_mlae8({{ + auto *isa = dynamic_cast( + xc->tcBase()->getIsaPtr()); + assert(isa != nullptr); + return isa->matrixLoadA8( + xc, Rs1_ud, Rs2_ud, MATRIX_MD4); + }}, MemReadOp); + } + 0x4: decode MATRIX_LS { + 0x1: MatrixConfOp::legacy_msettilem_i_hi({{ + auto *isa = dynamic_cast( + xc->tcBase()->getIsaPtr()); + assert(isa != nullptr); + isa->setMatrixTileM(MATRIX_IMM10); + }}, IntAluOp); + } + 0x5: decode MATRIX_LS { + 0x1: MatrixConfOp::legacy_msettilen_i({{ + auto *isa = dynamic_cast( + xc->tcBase()->getIsaPtr()); + assert(isa != nullptr); + isa->setMatrixTileN(MATRIX_IMM10); + }}, IntAluOp); + } + 0x6: decode MATRIX_LS { + 0x1: MatrixConfOp::legacy_msettilek_i({{ + auto *isa = dynamic_cast( + xc->tcBase()->getIsaPtr()); + assert(isa != nullptr); + isa->setMatrixTileK(MATRIX_IMM10); + }}, IntAluOp); + } + } + 0x02: decode FUNCT3 { + 0x0: decode MATRIX_LS { + 0x0: MatrixMemOp::legacy_mlbe8({{ + auto *isa = dynamic_cast( + xc->tcBase()->getIsaPtr()); + assert(isa != nullptr); + return isa->matrixLoadB8( + xc, Rs1_ud, Rs2_ud, MATRIX_MD4); + }}, MemReadOp); + } + } + 0x0a: decode FUNCT3 { + 0x4: decode MATRIX_MA { + 0x1: MatrixArithOp::legacy_mmacc_w_b({{ + auto *isa = dynamic_cast( + xc->tcBase()->getIsaPtr()); + assert(isa != nullptr); + isa->matrixMMAccWB(xc, MATRIX_MS1_4, + MATRIX_MS2_4, MATRIX_MD4); + }}, IntMultOp); + } + } + 0x30: decode RD { + 0x0: MatrixConfOp::legacy_msyncreset({{ + auto *isa = dynamic_cast( + xc->tcBase()->getIsaPtr()); + assert(isa != nullptr); + isa->matrixSyncReset(RS2); + }}, IsNonSpeculative, IsSerializeAfter, No_OpClass); + } + 0x31: decode RD { + 0x0: MatrixConfOp::legacy_mrelease({{ + auto *isa = dynamic_cast( + xc->tcBase()->getIsaPtr()); + assert(isa != nullptr); + isa->matrixRelease(xc, RS2); + }}, IsNonSpeculative, IsSerializeAfter, IntAluOp); + } + 0x32: decode RD { + 0x0: MatrixConfOp::legacy_macquire({{ + auto *isa = dynamic_cast( + xc->tcBase()->getIsaPtr()); + assert(isa != nullptr); + 
isa->matrixAcquire(xc, RS2, Rs1_ud); + }}, IsNonSpeculative, IsSerializeAfter, IsQuiesce, + IntAluOp); + } + } + 0x1e: M5Op::M5Op(); 0x1a: NemuOp::NemuOp(); diff --git a/src/arch/riscv/isa/includes.isa b/src/arch/riscv/isa/includes.isa index 4388c8df2d..32c79e5374 100644 --- a/src/arch/riscv/isa/includes.isa +++ b/src/arch/riscv/isa/includes.isa @@ -50,7 +50,6 @@ output header {{ #include "arch/riscv/decoder.hh" #include "arch/riscv/insts/amo.hh" #include "arch/riscv/insts/compressed.hh" -#include "arch/riscv/insts/matrix.hh" #include "arch/riscv/insts/mem.hh" #include "arch/riscv/insts/pseudo.hh" #include "arch/riscv/insts/standard.hh" @@ -100,7 +99,6 @@ output exec {{ #include "arch/generic/memhelpers.hh" #include "arch/riscv/faults.hh" #include "arch/riscv/fp_inst.hh" -#include "arch/riscv/insts/matrix.hh" #include "arch/riscv/mmu.hh" #include "arch/riscv/reg_abi.hh" #include "arch/riscv/regs/float.hh" diff --git a/src/arch/riscv/regs/misc.hh b/src/arch/riscv/regs/misc.hh index 058942e135..44acff698b 100644 --- a/src/arch/riscv/regs/misc.hh +++ b/src/arch/riscv/regs/misc.hh @@ -265,6 +265,28 @@ enum MiscRegIndex // non-maskable-interrupt-pending: NMI version of xIP MISCREG_NMIP, + MISCREG_MCSR, + MISCREG_MXRM, + MISCREG_MSAT, + MISCREG_MFFLAGS, + MISCREG_MFRM, + MISCREG_MSATEN, + MISCREG_MTYPE, + MISCREG_MTILEM, + MISCREG_MTILEN, + MISCREG_MTILEK, + MISCREG_MLENB, + MISCREG_MRLENB, + MISCREG_MAMUL, + MISCREG_XMISA, + MISCREG_XTLENB, + MISCREG_XTRLENB, + MISCREG_XALENB, + MISCREG_MTOK, + MISCREG_XMTILEM, + MISCREG_XMTILEN, + MISCREG_XMTILEK, + NUM_MISCREGS }; @@ -463,6 +485,28 @@ enum CSRIndex CSR_VTYPE = 0xC21, CSR_VLENB = 0xC22, + CSR_MCSR = 0x802, + CSR_MXRM = 0x806, + CSR_MSAT = 0x807, + CSR_MFFLAGS = 0x808, + CSR_MFRM = 0x809, + CSR_MSATEN = 0x80A, + CSR_MTYPE = 0xC40, + CSR_MTILEM_LEGACY = 0xC41, + CSR_MTILEN_LEGACY = 0xC42, + CSR_MTILEK_LEGACY = 0xC43, + CSR_MLENB = 0xC44, + CSR_MRLENB = 0xC45, + CSR_MAMUL = 0xC46, + CSR_XMISA = 0xCC0, + CSR_XTLENB = 
0xCC1, + CSR_XTRLENB = 0xCC2, + CSR_XALENB = 0xCC3, + CSR_MTOK = 0xCC4, + CSR_XMTILEM = 0xCC5, + CSR_XMTILEN = 0xCC6, + CSR_XMTILEK = 0xCC7, + CSR_HSTATUS = 0x600, CSR_HEDELEG = 0x602, CSR_HIDELEG = 0x603, @@ -689,6 +733,28 @@ const std::map CSRData = { {CSR_VTYPE, {"vtype" , MISCREG_VTYPE}}, {CSR_VLENB, {"VLENB" , MISCREG_VLENB}}, + {CSR_MCSR, {"mcsr", MISCREG_MCSR}}, + {CSR_MXRM, {"mxrm", MISCREG_MXRM}}, + {CSR_MSAT, {"msat", MISCREG_MSAT}}, + {CSR_MFFLAGS, {"mfflags", MISCREG_MFFLAGS}}, + {CSR_MFRM, {"mfrm", MISCREG_MFRM}}, + {CSR_MSATEN, {"msaten", MISCREG_MSATEN}}, + {CSR_MTYPE, {"mtype", MISCREG_MTYPE}}, + {CSR_MTILEM_LEGACY, {"mtilem", MISCREG_MTILEM}}, + {CSR_MTILEN_LEGACY, {"mtilen", MISCREG_MTILEN}}, + {CSR_MTILEK_LEGACY, {"mtilek", MISCREG_MTILEK}}, + {CSR_MLENB, {"mlenb", MISCREG_MLENB}}, + {CSR_MRLENB, {"mrlenb", MISCREG_MRLENB}}, + {CSR_MAMUL, {"mamul", MISCREG_MAMUL}}, + {CSR_XMISA, {"xmisa", MISCREG_XMISA}}, + {CSR_XTLENB, {"xtlenb", MISCREG_XTLENB}}, + {CSR_XTRLENB, {"xtrlenb", MISCREG_XTRLENB}}, + {CSR_XALENB, {"xalenb", MISCREG_XALENB}}, + {CSR_MTOK, {"mtok", MISCREG_MTOK}}, + {CSR_XMTILEM, {"xmtilem", MISCREG_XMTILEM}}, + {CSR_XMTILEN, {"xmtilen", MISCREG_XMTILEN}}, + {CSR_XMTILEK, {"xmtilek", MISCREG_XMTILEK}}, + {CSR_HSTATUS, {"hstatus", MISCREG_HSTATUS}}, {CSR_HEDELEG, {"hedeleg", MISCREG_HEDELEG}}, {CSR_HIDELEG, {"hideleg", MISCREG_HIDELEG}}, @@ -917,6 +983,15 @@ const RegVal NEMU_MIP_MASK = ((1 << 9) | (1 << 5) | (1 << 2) |(1 << 1)); const RegVal SI_MASK = SEI_MASK | STI_MASK | SSI_MASK; const RegVal FFLAGS_MASK = (1 << FRM_OFFSET) - 1; const RegVal FRM_MASK = 0x7; +const RegVal MCSR_MXRM_MASK = 0x3; +const RegVal MCSR_MSAT_MASK = 0x1; +const RegVal MCSR_MFFLAGS_MASK = 0x1f; +const RegVal MCSR_MFRM_MASK = 0x7; +const RegVal MCSR_MSATEN_MASK = 0x1; +const RegVal MCSR_MASK = MCSR_MXRM_MASK | (MCSR_MSAT_MASK << 2) | + (MCSR_MFFLAGS_MASK << 3) | + (MCSR_MFRM_MASK << 8) | + (MCSR_MSATEN_MASK << 11); const RegVal NEMU_MIE_MASK_BASE = 0xaaa; const 
RegVal NEMU_MIE_MASK_H = (1 << 2) | (1 << 6) | (1 << 10) | (1 << 12); const RegVal NEMU_LCOFI = 0; @@ -929,7 +1004,13 @@ const std::map CSRMasks = { {CSR_SSTATUS, SSTATUS_MASK}, {CSR_SIP, SI_MASK}, {CSR_MISA, MISA_MASK}, - {CSR_MIE,NEMU_MIE_MASK} + {CSR_MIE,NEMU_MIE_MASK}, + {CSR_MCSR, MCSR_MASK}, + {CSR_MXRM, MCSR_MXRM_MASK}, + {CSR_MSAT, MCSR_MSAT_MASK}, + {CSR_MFFLAGS, MCSR_MFFLAGS_MASK}, + {CSR_MFRM, MCSR_MFRM_MASK}, + {CSR_MSATEN, MCSR_MSATEN_MASK} }; #define concat_temp(x, y) x ## y diff --git a/src/matrix/SConscript b/src/matrix/SConscript new file mode 100644 index 0000000000..6a83c79177 --- /dev/null +++ b/src/matrix/SConscript @@ -0,0 +1,6 @@ +Import('*') + +Source('matrix_controller.cc') +GTest('matrix_controller.test', 'matrix_controller.test.cc', + 'matrix_controller.cc', '../sim/cur_tick.cc', '../base/debug.cc', + '../base/str.cc', 'matrix_stats_test_stub.cc') diff --git a/src/matrix/matrix_controller.cc b/src/matrix/matrix_controller.cc new file mode 100644 index 0000000000..3e2f9080ea --- /dev/null +++ b/src/matrix/matrix_controller.cc @@ -0,0 +1,2079 @@ +/* + * XSAI CUTE-aligned matrix controller scaffold. 
+ */ + +#include "matrix/matrix_controller.hh" + +#include +#include +#include + +#include "base/logging.hh" +#include "cpu/exec_context.hh" +#include "sim/cur_tick.hh" + +#ifndef UNIT_TEST +#include "arch/generic/mmu.hh" +#include "cpu/base.hh" +#include "cpu/thread_context.hh" +#include "mem/se_translating_port_proxy.hh" +#include "mem/translating_port_proxy.hh" +#include "sim/full_system.hh" +#endif + +namespace gem5 +{ + +namespace matrix +{ + +namespace +{ + +#ifndef UNIT_TEST +std::string +indexedName(const char *base, size_t idx) +{ + return std::string(base) + std::to_string(idx); +} + +bool +checkpointEntryExists(CheckpointIn &cp, const std::string &name) +{ + return cp.entryExists(Serializable::currentSection(), name); +} +#endif + +#ifndef UNIT_TEST +Fault +matrixReadBlob(ThreadContext *tc, Addr addr, void *dst, size_t size) +{ + bool ok = false; + if (FullSystem) { + TranslatingPortProxy proxy(tc); + ok = proxy.tryReadBlob(addr, dst, size); + } else { + SETranslatingPortProxy proxy(tc); + ok = proxy.tryReadBlob(addr, dst, size); + } + + if (!ok) { + return std::make_shared(addr); + } + return NoFault; +} + +Fault +matrixWriteBlob(ThreadContext *tc, Addr addr, const void *src, size_t size) +{ + bool ok = false; + if (FullSystem) { + TranslatingPortProxy proxy(tc); + ok = proxy.tryWriteBlob(addr, src, size); + } else { + SETranslatingPortProxy proxy(tc); + ok = proxy.tryWriteBlob(addr, src, size); + } + + if (!ok) { + return std::make_shared(addr); + } + return NoFault; +} +#endif + +uint8_t +selectHighestFreeSource( + const std::array &busy) +{ + for (uint32_t i = MatrixController::LocalMmuSourceCount; i > 0; --i) { + const uint32_t source = i - 1; + if (!busy[source]) { + return static_cast(source); + } + } + return MatrixController::LocalMmuSourceCount; +} + +Tick +peekIssueSlot(Tick earliest_tick, Tick issue_slot_tick, + uint32_t issue_slots_used, uint32_t issue_per_cycle, Tick cycle_ticks) +{ + const uint32_t slots_per_cycle = 
std::max(issue_per_cycle, 1); + const Tick cycle = std::max(cycle_ticks, 1); + Tick issue_tick = std::max(earliest_tick, issue_slot_tick); + uint32_t slots_used = issue_slots_used; + if (issue_tick > issue_slot_tick) { + slots_used = 0; + } + if (slots_used >= slots_per_cycle) { + issue_tick = std::max(issue_tick, issue_slot_tick + cycle); + } + return issue_tick; +} + +Tick +reserveIssueSlot(Tick earliest_tick, Tick &issue_slot_tick, + uint32_t &issue_slots_used, uint32_t issue_per_cycle, Tick cycle_ticks) +{ + const uint32_t slots_per_cycle = std::max(issue_per_cycle, 1); + const Tick cycle = std::max(cycle_ticks, 1); + Tick issue_tick = std::max(earliest_tick, issue_slot_tick); + if (issue_tick > issue_slot_tick) { + issue_slot_tick = issue_tick; + issue_slots_used = 0; + } + if (issue_slots_used >= slots_per_cycle) { + issue_slot_tick += cycle; + issue_slots_used = 0; + issue_tick = std::max(issue_tick, issue_slot_tick); + } + ++issue_slots_used; + return issue_tick; +} + +} // namespace + +MatrixController::Stats::Stats(statistics::Group *parent) + : statistics::Group(parent, nullptr), + ADD_STAT(tasksAccepted, statistics::units::Count::get(), + "Matrix controller tasks accepted"), + ADD_STAT(tasksIssued, statistics::units::Count::get(), + "Matrix controller tasks issued"), + ADD_STAT(tasksCompleted, statistics::units::Count::get(), + "Matrix controller tasks completed"), + ADD_STAT(tasksAborted, statistics::units::Count::get(), + "Matrix controller tasks aborted"), + ADD_STAT(aPortTasks, statistics::units::Count::get(), + "Matrix A-port tasks"), + ADD_STAT(bPortTasks, statistics::units::Count::get(), + "Matrix B-port tasks"), + ADD_STAT(cLoadTasks, statistics::units::Count::get(), + "Matrix C-load tasks"), + ADD_STAT(cStoreTasks, statistics::units::Count::get(), + "Matrix C-store tasks"), + ADD_STAT(mmaTasks, statistics::units::Count::get(), + "Matrix MMA tasks"), + ADD_STAT(zeroTasks, statistics::units::Count::get(), + "Matrix zero tasks"), + 
ADD_STAT(releaseTasks, statistics::units::Count::get(), + "Matrix release tasks"), + ADD_STAT(memoryRequests, statistics::units::Count::get(), + "Analytic LocalMMU requests"), + ADD_STAT(memoryReadRequests, statistics::units::Count::get(), + "Analytic LocalMMU read requests"), + ADD_STAT(memoryWriteRequests, statistics::units::Count::get(), + "Analytic LocalMMU write requests"), + ADD_STAT(memoryBytes, statistics::units::Byte::get(), + "Functional matrix memory bytes"), + ADD_STAT(memoryReadBytes, statistics::units::Byte::get(), + "Functional matrix memory read bytes"), + ADD_STAT(memoryWriteBytes, statistics::units::Byte::get(), + "Functional matrix memory write bytes"), + ADD_STAT(memoryBusBytes, statistics::units::Byte::get(), + "Analytic matrix memory bus bytes"), + ADD_STAT(memoryReadBusBytes, statistics::units::Byte::get(), + "Analytic matrix memory read bus bytes"), + ADD_STAT(memoryWriteBusBytes, statistics::units::Byte::get(), + "Analytic matrix memory write bus bytes"), + ADD_STAT(localMmuSourceAllocations, statistics::units::Count::get(), + "Analytic LocalMMU source allocations"), + ADD_STAT(localMmuSourceReleases, statistics::units::Count::get(), + "Analytic LocalMMU source releases"), + ADD_STAT(localMmuArbitrations, statistics::units::Count::get(), + "Analytic LocalMMU port arbitrations"), + ADD_STAT(localMmuReadDataResponses, statistics::units::Count::get(), + "Analytic LocalMMU read responses"), + ADD_STAT(localMmuWriteAcks, statistics::units::Count::get(), + "Analytic LocalMMU write acknowledgements"), + ADD_STAT(localMmuMaxOutstanding, statistics::units::Count::get(), + "Maximum analytic LocalMMU outstanding sources"), + ADD_STAT(memoryPipelineRequests, statistics::units::Count::get(), + "CUTE memory pipeline requests launched toward L2"), + ADD_STAT(memoryPipelineReadResponses, statistics::units::Count::get(), + "CUTE memory pipeline read data responses"), + ADD_STAT(memoryPipelineWriteAcks, statistics::units::Count::get(), + "CUTE memory 
pipeline write acknowledgements"), + ADD_STAT(memoryPipelineSourceStallTicks, statistics::units::Tick::get(), + "Ticks CUTE memory requests waited for a free source id"), + ADD_STAT(memoryPipelineRequestQueueTicks, statistics::units::Tick::get(), + "Ticks CUTE memory requests waited for LocalMMU/TL-A issue slots"), + ADD_STAT(memoryPipelineResponseQueueTicks, statistics::units::Tick::get(), + "Ticks CUTE memory responses waited for the unified response port"), + ADD_STAT(memoryPipelineLastRequestTick, statistics::units::Tick::get(), + "Last CUTE memory pipeline request issue tick"), + ADD_STAT(memoryPipelineLastResponseTick, statistics::units::Tick::get(), + "Last CUTE memory pipeline response tick"), + ADD_STAT(memoryPipelineMaxOutstanding, statistics::units::Count::get(), + "Maximum CUTE memory pipeline outstanding sources"), + ADD_STAT(timingTasks, statistics::units::Count::get(), + "Matrix tasks scheduled by the timing model"), + ADD_STAT(timingQueueTicks, statistics::units::Tick::get(), + "Matrix task queueing ticks before issue"), + ADD_STAT(timingMaxQueueTicks, statistics::units::Tick::get(), + "Maximum matrix task queueing ticks before issue"), + ADD_STAT(timingBusyTicks, statistics::units::Tick::get(), + "Analytic matrix engine busy ticks"), + ADD_STAT(timingLastIssueTick, statistics::units::Tick::get(), + "Last analytic matrix issue tick"), + ADD_STAT(timingLastCompletionTick, statistics::units::Tick::get(), + "Last analytic matrix completion tick"), + ADD_STAT(acquireStallEvents, statistics::units::Count::get(), + "macquire events that quiesced the CPU"), + ADD_STAT(acquireStallTicks, statistics::units::Tick::get(), + "Total macquire quiesce ticks"), + ADD_STAT(tokenReleaseEvents, statistics::units::Count::get(), + "Matrix token release events scheduled"), + ADD_STAT(tokenReleaseDelayEvents, statistics::units::Count::get(), + "Matrix token release events delayed by observed L2 responses"), + ADD_STAT(tokenReleaseDelayTicks, 
statistics::units::Tick::get(), + "Total ticks added to pending matrix token release events by " + "observed L2 responses") +{ +} + +void +MatrixController::Stats::resetStats() +{ + statistics::Group::resetStats(); + memPortRequests.fill(0); + memPortReadDataResponses.fill(0); + memPortWriteAcks.fill(0); + localMmuPortSelections.fill(0); + taskEvents.fill(0); +} + +MatrixController::MatrixController(statistics::Group *parent) + : stats(parent) +{ + reset(); +} + +void +MatrixController::reset() +{ + resetDataState(); + resetControlState(); + resetTimingState(); + stats.resetStats(); +} + +void +MatrixController::setTimingConfig(const TimingConfig &config) +{ + timingConfig = config; + timingConfig.issueIntervalCycles = + std::max(timingConfig.issueIntervalCycles, 1); + timingConfig.localMmuIssuePerCycle = + std::max(timingConfig.localMmuIssuePerCycle, 1); + timingConfig.l2ResponsePipelineCycles = + std::max(timingConfig.l2ResponsePipelineCycles, 1); +} + +void +MatrixController::retireReadyTokensUpTo(Tick now) +{ + retireReadyTokens(now); +} + +bool +MatrixController::tokenTargetReached( + uint64_t token_idx, uint64_t target) const +{ + return token(token_idx) >= target; +} + +Tick +MatrixController::tokenReadyTick(uint64_t token_idx, uint64_t target) const +{ + return tokenTargetReadyTick(token_idx, target); +} + +void +MatrixController::recordAcquireStall(Tick ticks) +{ + ++stats.acquireStallEvents; + stats.acquireStallTicks += ticks; +} + +void +MatrixController::resetDataState() +{ + data.tileM = 0; + data.tileK = 0; + data.tileN = 0; + for (auto ® : data.abRegs) { + reg.assign(MatrixABRegBytes, 0); + } + for (auto ® : data.accRegs) { + reg.assign(MatrixAccElems, 0); + } + data.tokens.assign(TokenCount, 0); +} + +void +MatrixController::resetControlState() +{ + control = ControlState{}; + control.sourceToPort.fill(MemPort::Num); + control.sourceBusy.fill(false); +} + +void +MatrixController::resetTimingState() +{ + timing = TimingState{}; +} + 
+MatrixController::ControlSnapshot +MatrixController::controlSnapshot() const +{ + ControlSnapshot snapshot; + snapshot.decodedQueueHead = control.queueHead; + snapshot.decodedQueueSize = control.queueSize; + snapshot.nextLoadFifoIdx = control.nextLoadFifoIdx; + snapshot.nextComputeFifoIdx = control.nextComputeFifoIdx; + snapshot.nextStoreFifoIdx = control.nextStoreFifoIdx; + + for (size_t i = 0; i < FuTypeCount; ++i) { + if (control.fus[i].busy) { + snapshot.fuBusyMask |= 1U << i; + } + } + + for (size_t i = 0; i < MatrixRegCount; ++i) { + if (control.abRegs[i].busy) { + snapshot.abBusyMask |= 1U << i; + } + if (control.cRegs[i].busy) { + snapshot.cBusyMask |= 1U << i; + } + + snapshot.abPendingReaders[i] = control.abRegs[i].pendingReaders; + snapshot.cPendingReaders[i] = control.cRegs[i].pendingReaders; + if (control.abRegs[i].pendingReaders != 0) { + snapshot.abPendingReaderMask |= 1U << i; + } + if (control.cRegs[i].pendingReaders != 0) { + snapshot.cPendingReaderMask |= 1U << i; + } + } + + snapshot.pendingStore = control.pendingStore; + snapshot.firstMmuRequestIndex = control.firstMmuRequestIndex; + snapshot.nextLocalMmuSource = selectHighestFreeSource( + control.sourceBusy); + snapshot.localMmuOutstanding = control.localMmuOutstanding; + for (size_t i = 0; i < LocalMmuSourceCount; ++i) { + if (control.sourceBusy[i]) { + snapshot.localMmuBusySourceMask |= 1ULL << i; + } + } + snapshot.timingNextIssueTick = timing.nextIssueTick; + snapshot.timingPendingStoreReadyTick = timing.pendingStoreReadyTick; + snapshot.timingLastIssueTick = timing.lastIssueTick; + snapshot.timingLastCompletionTick = timing.lastCompletionTick; + snapshot.timingLocalMmuIssueSlotTick = timing.localMmuIssueSlotTick; + snapshot.timingLocalMmuResponseSlotTick = + timing.localMmuResponseSlotTick; + snapshot.timingLocalMmuLastRequestTick = + timing.localMmuLastRequestTick; + snapshot.timingLocalMmuLastResponseTick = + timing.localMmuLastResponseTick; + snapshot.timingLocalMmuOutstanding = + 
timingLocalMmuOutstanding(curTick()); + snapshot.pendingTokenEvents = pendingTokenEvents(); + + return snapshot; +} + +#ifdef UNIT_TEST +uint8_t +MatrixController::allocateLocalMmuSourceForTest(MemPort port) +{ + return allocateLocalMmuSource(port); +} + +void +MatrixController::releaseLocalMmuSourceForTest(uint8_t source) +{ + releaseLocalMmuSource(source); +} + +void +MatrixController::writeABRegForTest( + uint64_t reg_idx, uint32_t row, uint32_t col, int8_t value) +{ + panic_if(row >= MatrixMaxM || col >= MatrixMaxK, + "matrix AB test write out of range"); + abReg(reg_idx)[row * MatrixMaxK + col] = value; +} + +int32_t +MatrixController::readAccRegForTest( + uint64_t reg_idx, uint32_t row, uint32_t col) const +{ + panic_if(row >= MatrixMaxM || col >= MatrixMaxN, + "matrix accumulator test read out of range"); + return accReg(reg_idx)[row * MatrixMaxN + col]; +} + +RegVal +MatrixController::readTokenForTest(uint64_t token_idx) const +{ + return token(token_idx); +} + +Tick +MatrixController::tokenReadyTickForTest( + uint64_t token_idx, uint64_t target) const +{ + return tokenTargetReadyTick(token_idx, target); +} + +void +MatrixController::scheduleMemoryTimingForTest( + MemPort port, bool is_store, uint64_t requests) +{ + TaskDesc task; + task.op = TaskOp::LoadStore; + task.memPort = port; + task.isStore = is_store; + task.memoryRequests = requests; + task.memoryBytes = requests * OutsideDataWidthBytes; + task.memoryBusBytes = task.memoryBytes; + task.base = 0x1000; + task.stride = OutsideDataWidthBytes; + task.rows = static_cast(requests); + task.cols = OutsideDataWidthBytes; + task.bytesPerRow = OutsideDataWidthBytes; + task.fu = port == MemPort::A ? FuType::AML : + port == MemPort::B ? FuType::BML : + port == MemPort::CLoad ? FuType::CMLLoad : + port == MemPort::CStore ? 
FuType::CMLStore : + FuType::None; + task.destValid = !is_store; + task.destIsAcc = port == MemPort::CLoad; + task.destReg = 0; + task.srcValid[0] = is_store; + task.srcIsAcc[0] = port == MemPort::CStore; + task.srcReg[0] = 0; + task.srcReadPending[0] = is_store; + + beginTask(task); + scheduleTimingTask(nullptr, task); + completeTask(); +} + +void +MatrixController::scheduleMemoryTimingForTest(MemPort port, bool is_store, + Addr base, Addr stride, uint32_t rows, uint32_t cols, + ElemWidth elem_width, bool transpose) +{ + const FuType fu = port == MemPort::A ? FuType::AML : + port == MemPort::B ? FuType::BML : + port == MemPort::CLoad ? FuType::CMLLoad : + port == MemPort::CStore ? FuType::CMLStore : + FuType::None; + const bool is_acc = port == MemPort::CLoad || port == MemPort::CStore; + TaskDesc task = makeMemoryTask(port, fu, is_store, is_acc, base, stride, + rows, cols, elem_width, transpose, 0); + + beginTask(task); + scheduleTimingTask(nullptr, task); + completeTask(); +} +#endif + +void +MatrixController::normalizeDataState() +{ + data.tileM = clampTileM(data.tileM); + data.tileK = clampTileK(data.tileK); + data.tileN = clampTileN(data.tileN); + for (auto ® : data.abRegs) { + reg.resize(MatrixABRegBytes, 0); + } + for (auto ® : data.accRegs) { + reg.resize(MatrixAccElems, 0); + } + data.tokens.resize(TokenCount, 0); +} + +void +MatrixController::serialize(CheckpointOut &cp) const +{ +#ifdef UNIT_TEST + panic("MatrixController checkpoint serialization is not linked in unit " + "tests"); +#else + paramOut(cp, "matrixTileM", data.tileM); + paramOut(cp, "matrixTileK", data.tileK); + paramOut(cp, "matrixTileN", data.tileN); + + for (size_t i = 0; i < MatrixRegCount; ++i) { + arrayParamOut(cp, indexedName("matrixABReg", i), data.abRegs[i]); + arrayParamOut(cp, indexedName("matrixAccReg", i), data.accRegs[i]); + } + + arrayParamOut(cp, "matrixTileA", data.abRegs[DefaultAReg]); + arrayParamOut(cp, "matrixTileB", data.abRegs[DefaultBReg]); + arrayParamOut(cp, 
"matrixAcc", data.accRegs[DefaultAccReg]); + arrayParamOut(cp, "matrixTokens", data.tokens); + + paramOut(cp, "matrixTimingNextIssueTick", timing.nextIssueTick); + paramOut(cp, "matrixTimingPendingStoreReadyTick", + timing.pendingStoreReadyTick); + paramOut(cp, "matrixTimingLastIssueTick", timing.lastIssueTick); + paramOut(cp, "matrixTimingLastCompletionTick", + timing.lastCompletionTick); + paramOut(cp, "matrixTimingLocalMmuIssueSlotTick", + timing.localMmuIssueSlotTick); + paramOut(cp, "matrixTimingLocalMmuIssueSlotsUsed", + timing.localMmuIssueSlotsUsed); + paramOut(cp, "matrixTimingLocalMmuResponseSlotTick", + timing.localMmuResponseSlotTick); + paramOut(cp, "matrixTimingLocalMmuLastRequestTick", + timing.localMmuLastRequestTick); + paramOut(cp, "matrixTimingLocalMmuLastResponseTick", + timing.localMmuLastResponseTick); + arrayParamOut(cp, "matrixTimingLocalMmuSourceReadyTicks", + timing.localMmuSourceReadyTick); + for (size_t i = 0; i < TokenCount; ++i) { + arrayParamOut(cp, indexedName("matrixTokenReadyTicks", i), + timing.tokenReadyTicks[i]); + } +#endif +} + +void +MatrixController::unserialize(CheckpointIn &cp) +{ +#ifdef UNIT_TEST + panic("MatrixController checkpoint unserialization is not linked in unit " + "tests"); +#else + paramIn(cp, "matrixTileM", data.tileM); + paramIn(cp, "matrixTileK", data.tileK); + paramIn(cp, "matrixTileN", data.tileN); + + for (auto ® : data.abRegs) { + reg.assign(MatrixABRegBytes, 0); + } + for (auto ® : data.accRegs) { + reg.assign(MatrixAccElems, 0); + } + + bool found_new_regs = false; + for (size_t i = 0; i < MatrixRegCount; ++i) { + const std::string ab_name = indexedName("matrixABReg", i); + if (checkpointEntryExists(cp, ab_name)) { + arrayParamIn(cp, ab_name, data.abRegs[i]); + found_new_regs = true; + } + + const std::string acc_name = indexedName("matrixAccReg", i); + if (checkpointEntryExists(cp, acc_name)) { + arrayParamIn(cp, acc_name, data.accRegs[i]); + found_new_regs = true; + } + } + + if (!found_new_regs) { 
+ arrayParamIn(cp, "matrixTileA", data.abRegs[DefaultAReg]); + arrayParamIn(cp, "matrixTileB", data.abRegs[DefaultBReg]); + arrayParamIn(cp, "matrixAcc", data.accRegs[DefaultAccReg]); + } + arrayParamIn(cp, "matrixTokens", data.tokens); + + normalizeDataState(); + resetControlState(); + resetTimingState(); + if (checkpointEntryExists(cp, "matrixTimingNextIssueTick")) { + paramIn(cp, "matrixTimingNextIssueTick", timing.nextIssueTick); + } + if (checkpointEntryExists(cp, "matrixTimingPendingStoreReadyTick")) { + paramIn(cp, "matrixTimingPendingStoreReadyTick", + timing.pendingStoreReadyTick); + } + if (checkpointEntryExists(cp, "matrixTimingLastIssueTick")) { + paramIn(cp, "matrixTimingLastIssueTick", timing.lastIssueTick); + } + if (checkpointEntryExists(cp, "matrixTimingLastCompletionTick")) { + paramIn(cp, "matrixTimingLastCompletionTick", + timing.lastCompletionTick); + } + if (checkpointEntryExists(cp, "matrixTimingLocalMmuIssueSlotTick")) { + paramIn(cp, "matrixTimingLocalMmuIssueSlotTick", + timing.localMmuIssueSlotTick); + } + if (checkpointEntryExists(cp, "matrixTimingLocalMmuIssueSlotsUsed")) { + paramIn(cp, "matrixTimingLocalMmuIssueSlotsUsed", + timing.localMmuIssueSlotsUsed); + } + if (checkpointEntryExists(cp, "matrixTimingLocalMmuResponseSlotTick")) { + paramIn(cp, "matrixTimingLocalMmuResponseSlotTick", + timing.localMmuResponseSlotTick); + } + if (checkpointEntryExists(cp, "matrixTimingLocalMmuLastRequestTick")) { + paramIn(cp, "matrixTimingLocalMmuLastRequestTick", + timing.localMmuLastRequestTick); + } + if (checkpointEntryExists(cp, "matrixTimingLocalMmuLastResponseTick")) { + paramIn(cp, "matrixTimingLocalMmuLastResponseTick", + timing.localMmuLastResponseTick); + } + if (checkpointEntryExists(cp, "matrixTimingLocalMmuSourceReadyTicks")) { + arrayParamIn(cp, "matrixTimingLocalMmuSourceReadyTicks", + timing.localMmuSourceReadyTick.data(), + timing.localMmuSourceReadyTick.size()); + } + for (size_t i = 0; i < TokenCount; ++i) { + const std::string 
token_ready_name = + indexedName("matrixTokenReadyTicks", i); + if (checkpointEntryExists(cp, token_ready_name)) { + arrayParamIn(cp, token_ready_name, timing.tokenReadyTicks[i]); + } + } +#endif +} + +RegVal & +MatrixController::token(uint64_t idx) +{ + panic_if(idx >= data.tokens.size(), + "matrix token index %llu out of range", + static_cast(idx)); + return data.tokens[idx]; +} + +const RegVal & +MatrixController::token(uint64_t idx) const +{ + panic_if(idx >= data.tokens.size(), + "matrix token index %llu out of range", + static_cast(idx)); + return data.tokens[idx]; +} + +std::vector & +MatrixController::abReg(uint64_t idx) +{ + return data.abRegs[normalizeRegIdx(idx)]; +} + +const std::vector & +MatrixController::abReg(uint64_t idx) const +{ + return data.abRegs[normalizeRegIdx(idx)]; +} + +std::vector & +MatrixController::accReg(uint64_t idx) +{ + return data.accRegs[normalizeRegIdx(idx)]; +} + +const std::vector & +MatrixController::accReg(uint64_t idx) const +{ + return data.accRegs[normalizeRegIdx(idx)]; +} + +void +MatrixController::syncReset(uint64_t token_idx) +{ + token(token_idx) = 0; + timing.tokenReadyTicks[token_idx].clear(); +} + +void +MatrixController::release(ExecContext *xc, uint64_t token_idx) +{ + TaskDesc task = makeReleaseTask(token_idx); + beginTask(task); + const Tick ready_tick = scheduleTimingTask(xc, task); + enqueueTokenRelease(token_idx, ready_tick); + completeTask(); +} + +void +MatrixController::acquire(ExecContext *xc, uint64_t token_idx, uint64_t target) +{ +#ifdef UNIT_TEST + panic("MatrixController::acquire requires a ThreadContext in unit tests"); +#else + ThreadContext *tc = xc->tcBase(); + const Tick now = curTick(); + retireReadyTokens(now); + + const RegVal observed = token(token_idx); + if (observed >= target) { + return; + } + + const Tick ready_tick = tokenTargetReadyTick(token_idx, target); + panic_if(ready_tick == MaxTick, + "macquire tok%llu target=%llu observed=%llu has no pending release", + 
static_cast(token_idx), + static_cast(target), + static_cast(observed)); + + if (ready_tick > now) { + ++stats.acquireStallEvents; + stats.acquireStallTicks += ready_tick - now; + tc->quiesceTick(ready_tick); + } else { + retireReadyTokens(ready_tick); + } +#endif +} + +void +MatrixController::setTileM(uint64_t value) +{ + data.tileM = clampTileM(value); +} + +void +MatrixController::setTileK(uint64_t value) +{ + data.tileK = clampTileK(value); +} + +void +MatrixController::setTileN(uint64_t value) +{ + data.tileN = clampTileN(value); +} + +Fault +MatrixController::load(ExecContext *xc, MemPort port, Addr base, Addr stride, + uint32_t rows, uint32_t cols, ElemWidth width, bool transpose, + uint64_t reg_idx) +{ + panic_if(port != MemPort::A && port != MemPort::B && + port != MemPort::CLoad, "unsupported matrix load port"); + + const bool is_acc = port == MemPort::CLoad; + const FuType fu = port == MemPort::A ? FuType::AML : + port == MemPort::B ? FuType::BML : FuType::CMLLoad; + const uint8_t reg = normalizeRegIdx(reg_idx); + TaskDesc task = makeMemoryTask(port, fu, false, is_acc, base, stride, + rows, cols, width, transpose, reg); + + beginTask(task); + Fault fault = executeLoad(xc, task); + if (fault != NoFault) { + abortTask(); + return fault; + } + + scheduleTimingTask(xc, task); + completeTask(); + return NoFault; +} + +Fault +MatrixController::store(ExecContext *xc, MemPort port, Addr base, Addr stride, + uint32_t rows, uint32_t cols, ElemWidth width, bool transpose, + uint64_t reg_idx) +{ + panic_if(port != MemPort::CStore, + "matrix store currently models the RTL CML-store path only"); + + const bool is_acc = true; + const FuType fu = FuType::CMLStore; + const uint8_t reg = normalizeRegIdx(reg_idx); + TaskDesc task = makeMemoryTask(port, fu, true, is_acc, base, stride, + rows, cols, width, transpose, reg); + + beginTask(task); + Fault fault = executeStore(xc, task); + if (fault != NoFault) { + abortTask(); + return fault; + } + + scheduleTimingTask(xc, 
task); + completeTask(); + return NoFault; +} + +Fault +MatrixController::loadA8( + ExecContext *xc, Addr base, Addr stride, uint64_t reg_idx) +{ + return load(xc, MemPort::A, base, stride, data.tileM, data.tileK, + ElemWidth::E8, false, reg_idx); +} + +Fault +MatrixController::loadB8( + ExecContext *xc, Addr base, Addr stride, uint64_t reg_idx) +{ + return load(xc, MemPort::B, base, stride, data.tileN, data.tileK, + ElemWidth::E8, false, reg_idx); +} + +Fault +MatrixController::loadC32( + ExecContext *xc, Addr base, Addr stride, uint64_t acc_idx) +{ + return load(xc, MemPort::CLoad, base, stride, data.tileM, data.tileN, + ElemWidth::E32, false, acc_idx); +} + +Fault +MatrixController::storeC32( + ExecContext *xc, Addr base, Addr stride, uint64_t acc_idx) +{ + return store(xc, MemPort::CStore, base, stride, data.tileM, data.tileN, + ElemWidth::E32, false, acc_idx); +} + +void +MatrixController::zeroAcc(ExecContext *xc, uint64_t acc_idx) +{ + zero(xc, acc_idx, true); +} + +void +MatrixController::zero(ExecContext *xc, uint64_t reg_idx, bool is_acc) +{ + const uint8_t reg = normalizeRegIdx(reg_idx); + TaskDesc task = makeZeroTask(reg, is_acc); + beginTask(task); + executeZero(task); + scheduleTimingTask(xc, task); + completeTask(); +} + +void +MatrixController::mmaccWB( + uint64_t src_a_idx, uint64_t src_b_idx, uint64_t dst_acc_idx) +{ + mmaccWB(nullptr, src_a_idx, src_b_idx, dst_acc_idx); +} + +void +MatrixController::mmaccWB(ExecContext *xc, + uint64_t src_a_idx, uint64_t src_b_idx, uint64_t dst_acc_idx) +{ + mmacc(xc, src_a_idx, src_b_idx, dst_acc_idx, data.tileM, data.tileN, + data.tileK); +} + +void +MatrixController::mmacc(ExecContext *xc, + uint64_t src_a_idx, uint64_t src_b_idx, uint64_t dst_acc_idx, + uint32_t rows, uint32_t cols, uint32_t depth) +{ + const uint8_t a_reg_idx = normalizeRegIdx(src_a_idx); + const uint8_t b_reg_idx = normalizeRegIdx(src_b_idx); + const uint8_t acc_reg_idx = normalizeRegIdx(dst_acc_idx); + TaskDesc task = makeMmaTask(a_reg_idx, 
b_reg_idx, acc_reg_idx, + rows, cols, depth); + beginTask(task); + executeMma(task); + scheduleTimingTask(xc, task); + completeTask(); +} + +MatrixController::TaskDesc +MatrixController::makeMemoryTask(MemPort port, FuType fu, bool is_store, + bool is_acc, Addr base, Addr stride, uint32_t rows, uint32_t cols, + ElemWidth elem_width, bool transpose, uint8_t reg_idx) const +{ + TaskDesc task; + const uint32_t bytes_per_row = cols * elemWidthBytes(elem_width); + task.op = TaskOp::LoadStore; + task.fu = fu; + task.memPort = port; + task.isStore = is_store; + task.transpose = transpose; + task.destValid = !is_store; + task.destIsAcc = is_acc; + task.destReg = reg_idx; + task.srcValid[0] = is_store; + task.srcIsAcc[0] = is_acc; + task.srcReg[0] = reg_idx; + task.srcReadPending[0] = is_store; + task.base = base; + task.stride = stride; + task.rows = rows; + task.cols = cols; + task.bytesPerRow = bytes_per_row; + task.elemWidth = elem_width; + task.memoryBytes = static_cast(rows) * bytes_per_row; + task.memoryRequests = matrixMemoryRequests(port, transpose, base, stride, + rows, cols, elem_width); + task.memoryBusBytes = task.memoryRequests * OutsideDataWidthBytes; + switch (port) { + case MemPort::A: + task.needMask = 0x1; + break; + case MemPort::B: + task.needMask = 0x2; + break; + case MemPort::CLoad: + case MemPort::CStore: + task.needMask = 0x4; + break; + default: + task.needMask = 0; + break; + } + return task; +} + +MatrixController::TaskDesc +MatrixController::makeMmaTask( + uint8_t src_a_idx, uint8_t src_b_idx, uint8_t dst_acc_idx, + uint32_t rows, uint32_t cols, uint32_t depth) const +{ + TaskDesc task; + task.op = TaskOp::Mma; + task.fu = FuType::Compute; + task.destValid = true; + task.destIsAcc = true; + task.destReg = dst_acc_idx; + task.srcValid[0] = true; + task.srcIsAcc[0] = false; + task.srcReg[0] = src_a_idx; + task.srcReadPending[0] = true; + task.srcValid[1] = true; + task.srcIsAcc[1] = false; + task.srcReg[1] = src_b_idx; + task.srcReadPending[1] = 
true; + task.srcValid[2] = true; + task.srcIsAcc[2] = true; + task.srcReg[2] = dst_acc_idx; + task.rows = rows; + task.cols = cols; + task.depth = depth; + task.needMask = 0x7; + return task; +} + +MatrixController::TaskDesc +MatrixController::makeZeroTask(uint8_t reg_idx, bool is_acc) const +{ + TaskDesc task; + task.op = TaskOp::Arith; + task.fu = is_acc ? FuType::CMLLoad : FuType::AML; + task.destValid = true; + task.destIsAcc = is_acc; + task.destReg = reg_idx; + task.rows = is_acc ? MatrixMaxM : MatrixMaxM; + task.cols = is_acc ? MatrixMaxN : MatrixMaxK; + task.needMask = is_acc ? 0x4 : 0x1; + return task; +} + +MatrixController::TaskDesc +MatrixController::makeReleaseTask(uint64_t token_idx) const +{ + TaskDesc task; + task.op = TaskOp::Release; + task.fu = FuType::None; + task.tokenIdx = token_idx; + return task; +} + +Fault +MatrixController::executeLoad(ExecContext *xc, const TaskDesc &task) +{ +#ifdef UNIT_TEST + panic("MatrixController memory loads are not linked in unit tests"); +#else + ThreadContext *tc = xc->tcBase(); + + if (task.destIsAcc) { + panic_if(task.elemWidth != ElemWidth::E32, + "C matrix functional load currently supports e32 only"); + panic_if(task.rows > MatrixMaxM || task.cols > MatrixMaxN, + "C matrix load shape exceeds accumulator register capacity"); + + auto &dst_reg = accReg(task.destReg); + for (uint32_t row = 0; row < task.rows; ++row) { + auto *dst = + reinterpret_cast(&dst_reg[row * MatrixMaxN]); + Fault fault = matrixReadBlob(tc, task.base + row * task.stride, + dst, task.bytesPerRow); + if (fault != NoFault) { + return fault; + } + } + return NoFault; + } + + panic_if(task.elemWidth != ElemWidth::E8, + "AB matrix functional load currently supports e8 only"); + panic_if(task.rows > MatrixMaxM || task.cols > MatrixMaxK, + "AB matrix load shape exceeds tile register capacity"); + + auto &dst_reg = abReg(task.destReg); + for (uint32_t row = 0; row < task.rows; ++row) { + auto *dst = reinterpret_cast(&dst_reg[row * MatrixMaxK]); 
+ Fault fault = matrixReadBlob(tc, task.base + row * task.stride, dst, + task.bytesPerRow); + if (fault != NoFault) { + return fault; + } + } + return NoFault; +#endif +} + +Fault +MatrixController::executeStore(ExecContext *xc, const TaskDesc &task) +{ +#ifdef UNIT_TEST + panic("MatrixController memory stores are not linked in unit tests"); +#else + ThreadContext *tc = xc->tcBase(); + + if (task.srcIsAcc[0]) { + panic_if(task.elemWidth != ElemWidth::E32, + "C matrix functional store currently supports e32 only"); + panic_if(task.rows > MatrixMaxM || task.cols > MatrixMaxN, + "C matrix store shape exceeds accumulator register capacity"); + + const auto &src_reg = accReg(task.srcReg[0]); + for (uint32_t row = 0; row < task.rows; ++row) { + const auto *src = + reinterpret_cast(&src_reg[row * MatrixMaxN]); + Fault fault = matrixWriteBlob(tc, task.base + row * task.stride, + src, task.bytesPerRow); + if (fault != NoFault) { + return fault; + } + } + return NoFault; + } + + panic_if(task.elemWidth != ElemWidth::E8, + "AB matrix functional store currently supports e8 only"); + panic_if(task.rows > MatrixMaxM || task.cols > MatrixMaxK, + "AB matrix store shape exceeds tile register capacity"); + + const auto &src_reg = abReg(task.srcReg[0]); + for (uint32_t row = 0; row < task.rows; ++row) { + const auto *src = + reinterpret_cast(&src_reg[row * MatrixMaxK]); + Fault fault = matrixWriteBlob(tc, task.base + row * task.stride, src, + task.bytesPerRow); + if (fault != NoFault) { + return fault; + } + } + return NoFault; +#endif +} + +void +MatrixController::executeZero(const TaskDesc &task) +{ + if (task.destIsAcc) { + std::fill(accReg(task.destReg).begin(), accReg(task.destReg).end(), + 0); + } else { + std::fill(abReg(task.destReg).begin(), abReg(task.destReg).end(), 0); + } +} + +void +MatrixController::executeMma(const TaskDesc &task) +{ + panic_if(task.rows > MatrixMaxM || task.cols > MatrixMaxN || + task.depth > MatrixMaxK, "matrix mma shape exceeds register capacity"); 
+ + const auto &a_reg = abReg(task.srcReg[0]); + const auto &b_reg = abReg(task.srcReg[1]); + auto &dst_reg = accReg(task.destReg); + for (uint32_t m = 0; m < task.rows; ++m) { + for (uint32_t n = 0; n < task.cols; ++n) { + int32_t acc = dst_reg[m * MatrixMaxN + n]; + for (uint32_t k = 0; k < task.depth; ++k) { + const int8_t a = a_reg[m * MatrixMaxK + k]; + const int8_t b = b_reg[n * MatrixMaxK + k]; + acc += static_cast(a) * static_cast(b); + } + dst_reg[m * MatrixMaxN + n] = acc; + } + } +} + +Tick +MatrixController::scheduleTimingTask(ExecContext *xc, const TaskDesc &task) +{ + const Tick now = curTick(); + + Tick issue_tick = std::max(now, timing.nextIssueTick); + if (task.op == TaskOp::Release) { + issue_tick = std::max(issue_tick, timing.pendingStoreReadyTick); + issue_tick = std::max(issue_tick, timing.lastCompletionTick); + } + + if (task.fu != FuType::None) { + issue_tick = std::max(issue_tick, timing.fuReadyTick[fuIndex(task.fu)]); + } + + if (task.destValid) { + const uint8_t dest = normalizeRegIdx(task.destReg); + const Tick write_ready = task.destIsAcc ? + timing.cWriteReadyTick[dest] : timing.abWriteReadyTick[dest]; + const Tick read_block = task.destIsAcc ? + timing.cReadBlockTick[dest] : timing.abReadBlockTick[dest]; + issue_tick = std::max(issue_tick, std::max(write_ready, read_block)); + } + + for (size_t i = 0; i < task.srcValid.size(); ++i) { + if (!task.srcValid[i]) { + continue; + } + + const uint8_t src = normalizeRegIdx(task.srcReg[i]); + const Tick write_ready = task.srcIsAcc[i] ? + timing.cWriteReadyTick[src] : timing.abWriteReadyTick[src]; + issue_tick = std::max(issue_tick, write_ready); + + if (task.isStore) { + const Tick read_block = task.srcIsAcc[i] ? + timing.cReadBlockTick[src] : timing.abReadBlockTick[src]; + issue_tick = std::max(issue_tick, read_block); + } + } + + const Tick complete_tick = task.op == TaskOp::LoadStore ? 
+ scheduleMemoryPipeline(xc, task, issue_tick) : + issue_tick + taskLatencyTicks(xc, task); + timing.nextIssueTick = issue_tick + + cyclesToTicks(xc, timingConfig.issueIntervalCycles); + timing.lastIssueTick = issue_tick; + timing.lastCompletionTick = + std::max(timing.lastCompletionTick, complete_tick); + + if (task.fu != FuType::None) { + timing.fuReadyTick[fuIndex(task.fu)] = complete_tick; + } + + if (task.destValid) { + const uint8_t dest = normalizeRegIdx(task.destReg); + if (task.destIsAcc) { + timing.cWriteReadyTick[dest] = complete_tick; + } else { + timing.abWriteReadyTick[dest] = complete_tick; + } + } + + for (size_t i = 0; i < task.srcValid.size(); ++i) { + if (!task.srcValid[i] || !task.srcReadPending[i]) { + continue; + } + + const uint8_t src = normalizeRegIdx(task.srcReg[i]); + const Tick read_done = + taskSourceReadDoneTick(xc, task, issue_tick, complete_tick, i); + if (task.srcIsAcc[i]) { + timing.cReadBlockTick[src] = + std::max(timing.cReadBlockTick[src], read_done); + } else { + timing.abReadBlockTick[src] = + std::max(timing.abReadBlockTick[src], read_done); + } + } + + if (task.isStore) { + timing.pendingStoreReadyTick = + std::max(timing.pendingStoreReadyTick, complete_tick); + } + + const Tick queue_ticks = issue_tick > now ? issue_tick - now : 0; + ++stats.timingTasks; + stats.timingQueueTicks += queue_ticks; + stats.timingMaxQueueTicks = + std::max( + static_cast(stats.timingMaxQueueTicks.value()), + queue_ticks); + stats.timingBusyTicks += complete_tick - issue_tick; + stats.timingLastIssueTick = issue_tick; + stats.timingLastCompletionTick = timing.lastCompletionTick; + + return complete_tick; +} + +Tick +MatrixController::scheduleMemoryPipeline( + ExecContext *xc, const TaskDesc &task, Tick issue_tick) +{ + panic_if(task.op != TaskOp::LoadStore, + "matrix memory pipeline scheduled for non-memory task"); + + const uint64_t base_cycles = + task.isStore ? 
timingConfig.storeBaseCycles : + timingConfig.loadBaseCycles; + const Tick request_ready_tick = issue_tick + + cyclesToTicks(xc, base_cycles + timingConfig.localMmuArbCycles); + const Tick l2_request_pipe = + cyclesToTicks(xc, timingConfig.l2RequestPipelineCycles); + const Tick response_latency = cyclesToTicks(xc, + task.isStore ? timingConfig.localMmuWriteAckLatencyCycles : + timingConfig.localMmuReadLatencyCycles); + + Tick complete_tick = request_ready_tick; + stats.memoryPipelineRequests += task.memoryRequests; + + for (uint64_t i = 0; i < task.memoryRequests; ++i) { + Tick issue_candidate = request_ready_tick; + Tick request_tick = request_ready_tick; + Tick source_ready_tick = request_ready_tick; + uint8_t source = 0; + while (true) { + const Tick candidate_request_tick = + peekLocalMmuIssueSlot(xc, issue_candidate); + if (candidate_request_tick > issue_candidate) { + stats.memoryPipelineRequestQueueTicks += + candidate_request_tick - issue_candidate; + } + + source = chooseTimingLocalMmuSource(candidate_request_tick); + source_ready_tick = timing.localMmuSourceReadyTick[source]; + if (source_ready_tick <= candidate_request_tick) { + request_tick = reserveLocalMmuIssueSlot(xc, issue_candidate); + panic_if(request_tick != candidate_request_tick, + "matrix LocalMMU issue slot changed after source " + "selection"); + break; + } + + stats.memoryPipelineSourceStallTicks += + source_ready_tick - candidate_request_tick; + issue_candidate = source_ready_tick; + } + + const Tick response_candidate = + request_tick + l2_request_pipe + response_latency; + const Tick response_tick = + reserveLocalMmuResponseSlot(xc, response_candidate); + if (response_tick > response_candidate) { + stats.memoryPipelineResponseQueueTicks += + response_tick - response_candidate; + } + + timing.localMmuLastRequestTick = + std::max(timing.localMmuLastRequestTick, request_tick); + timing.localMmuLastResponseTick = + std::max(timing.localMmuLastResponseTick, response_tick); + + const Tick 
source_reuse_tick = + response_tick + cyclesToTicks(xc, + timingConfig.l2ResponsePipelineCycles); + timing.localMmuSourceReadyTick[source] = source_reuse_tick; + const uint32_t outstanding = timingLocalMmuOutstanding(request_tick); + stats.memoryPipelineMaxOutstanding = std::max( + static_cast( + stats.memoryPipelineMaxOutstanding.value()), + outstanding); + + if (task.isStore) { + ++stats.memoryPipelineWriteAcks; + } else { + ++stats.memoryPipelineReadResponses; + } + stats.memoryPipelineLastRequestTick = + timing.localMmuLastRequestTick; + stats.memoryPipelineLastResponseTick = + timing.localMmuLastResponseTick; + + complete_tick = std::max(complete_tick, response_tick); + } + + return std::max(complete_tick, issue_tick + cyclesToTicks(xc, 1)); +} + +Tick +MatrixController::taskLatencyTicks(ExecContext *xc, const TaskDesc &task) const +{ + uint64_t cycles = 1; + switch (task.op) { + case TaskOp::LoadStore: + panic("matrix load/store timing must use scheduleMemoryPipeline"); + break; + case TaskOp::Mma: + cycles = timingConfig.computeBaseCycles; + break; + case TaskOp::Arith: + cycles = timingConfig.zeroCycles; + break; + case TaskOp::Release: + cycles = timingConfig.releaseCycles; + break; + } + + return cyclesToTicks(xc, std::max(cycles, 1)); +} + +Tick +MatrixController::taskSourceReadDoneTick( + ExecContext *xc, const TaskDesc &task, Tick issue_tick, + Tick complete_tick, size_t src_idx) const +{ + if (task.op == TaskOp::Mma && src_idx < 2) { + return std::min(complete_tick, + issue_tick + cyclesToTicks(xc, timingConfig.computeReadCycles)); + } + + if (task.op == TaskOp::Mma && src_idx == 2) { + return std::min(complete_tick, + issue_tick + cyclesToTicks(xc, timingConfig.computeReadCycles)); + } + + return complete_tick; +} + +Tick +MatrixController::cpuCycleTicks(ExecContext *xc) const +{ +#ifdef UNIT_TEST + return 1; +#else + if (xc != nullptr && xc->tcBase() != nullptr && + xc->tcBase()->getCpuPtr() != nullptr) { + return 
xc->tcBase()->getCpuPtr()->clockPeriod(); + } + return 1; +#endif +} + +Tick +MatrixController::cyclesToTicks(ExecContext *xc, uint64_t cycles) const +{ + return cpuCycleTicks(xc) * cycles; +} + +Tick +MatrixController::peekLocalMmuIssueSlot( + ExecContext *xc, Tick earliest_tick) const +{ + return peekIssueSlot(earliest_tick, + timing.localMmuIssueSlotTick, timing.localMmuIssueSlotsUsed, + timingConfig.localMmuIssuePerCycle, cyclesToTicks(xc, 1)); +} + +Tick +MatrixController::reserveLocalMmuIssueSlot( + ExecContext *xc, Tick earliest_tick) +{ + return reserveIssueSlot(earliest_tick, + timing.localMmuIssueSlotTick, timing.localMmuIssueSlotsUsed, + timingConfig.localMmuIssuePerCycle, cyclesToTicks(xc, 1)); +} + +Tick +MatrixController::reserveLocalMmuResponseSlot( + ExecContext *xc, Tick earliest_tick) +{ + const Tick response_tick = + std::max(earliest_tick, timing.localMmuResponseSlotTick); + timing.localMmuResponseSlotTick = response_tick + + cyclesToTicks(xc, timingConfig.l2ResponsePipelineCycles); + return response_tick; +} + +uint8_t +MatrixController::chooseTimingLocalMmuSource(Tick earliest_tick) const +{ + uint8_t source = 0; + Tick selected_ready_tick = MaxTick; + bool found_ready_source = false; + + for (uint32_t i = 0; i < LocalMmuSourceCount; ++i) { + const Tick ready_tick = timing.localMmuSourceReadyTick[i]; + if (ready_tick <= earliest_tick) { + source = static_cast(i); + selected_ready_tick = ready_tick; + found_ready_source = true; + continue; + } + + if (!found_ready_source && + (ready_tick < selected_ready_tick || + (ready_tick == selected_ready_tick && i > source))) { + source = static_cast(i); + selected_ready_tick = ready_tick; + } + } + + return source; +} + +uint32_t +MatrixController::timingLocalMmuOutstanding(Tick tick) const +{ + uint32_t count = 0; + for (Tick ready_tick : timing.localMmuSourceReadyTick) { + if (ready_tick > tick) { + ++count; + } + } + return count; +} + +void +MatrixController::retireReadyTokens(Tick now) +{ + for 
(size_t i = 0; i < TokenCount; ++i) { + auto &ready_ticks = timing.tokenReadyTicks[i]; + const auto ready_end = + std::upper_bound(ready_ticks.begin(), ready_ticks.end(), now); + const size_t ready_count = ready_end - ready_ticks.begin(); + if (ready_count == 0) { + continue; + } + + token(i) += ready_count; + ready_ticks.erase(ready_ticks.begin(), ready_end); + } +} + +void +MatrixController::enqueueTokenRelease(uint64_t token_idx, Tick ready_tick) +{ + panic_if(token_idx >= TokenCount, + "matrix token index %llu out of range", + static_cast(token_idx)); + + auto &ready_ticks = timing.tokenReadyTicks[token_idx]; + ready_ticks.insert(std::upper_bound( + ready_ticks.begin(), ready_ticks.end(), ready_tick), ready_tick); + ++stats.tokenReleaseEvents; +} + +void +MatrixController::delayPendingTokenEvents(Tick ready_tick) +{ + if (ready_tick == 0) { + return; + } + + for (auto &ready_ticks : timing.tokenReadyTicks) { + for (Tick &tick : ready_ticks) { + if (tick >= ready_tick) { + continue; + } + + stats.tokenReleaseDelayTicks += ready_tick - tick; + tick = ready_tick; + ++stats.tokenReleaseDelayEvents; + } + std::sort(ready_ticks.begin(), ready_ticks.end()); + } +} + +Tick +MatrixController::tokenTargetReadyTick(uint64_t token_idx, uint64_t target) const +{ + panic_if(token_idx >= TokenCount, + "matrix token index %llu out of range", + static_cast(token_idx)); + + const RegVal observed = token(token_idx); + if (observed >= target) { + return curTick(); + } + + const uint64_t pending_needed = target - observed; + const auto &ready_ticks = timing.tokenReadyTicks[token_idx]; + if (pending_needed == 0 || pending_needed > ready_ticks.size()) { + return MaxTick; + } + return ready_ticks[pending_needed - 1]; +} + +uint16_t +MatrixController::pendingTokenEvents() const +{ + size_t pending = 0; + for (const auto &ready_ticks : timing.tokenReadyTicks) { + pending += ready_ticks.size(); + } + return pending > UINT16_MAX ? 
UINT16_MAX : static_cast(pending); +} + +void +MatrixController::beginTask(const TaskDesc &task) +{ + panic_if(queueFull(), "matrix decoded FIFO full"); + + TaskDesc queued = task; + assignFifoIdx(queued); + pushTask(queued); + const TaskDesc &head = headTask(); + panic_if(!canIssue(head), "matrix controller task cannot issue"); + + reserveTask(head); + recordIssue(head); +} + +void +MatrixController::completeTask() +{ + panic_if(queueEmpty(), "matrix task completion with empty decoded FIFO"); + + const TaskDesc &task = headTask(); + recordCompletion(task); + releaseTask(task); + ++stats.tasksCompleted; + popTask(); +} + +void +MatrixController::abortTask() +{ + panic_if(queueEmpty(), "matrix task abort with empty decoded FIFO"); + + const TaskDesc &task = headTask(); + releaseTask(task); + ++stats.tasksAborted; + popTask(); +} + +void +MatrixController::assignFifoIdx(TaskDesc &task) +{ + switch (task.op) { + case TaskOp::Mma: + task.fifoIdx = control.nextComputeFifoIdx; + control.nextComputeFifoIdx = + (control.nextComputeFifoIdx + 1) & (MicroTaskFifoSlots - 1); + break; + case TaskOp::LoadStore: + if (task.isStore) { + task.fifoIdx = control.nextStoreFifoIdx; + control.nextStoreFifoIdx = + (control.nextStoreFifoIdx + 1) & (MicroTaskFifoSlots - 1); + } else { + task.fifoIdx = control.nextLoadFifoIdx; + control.nextLoadFifoIdx = + (control.nextLoadFifoIdx + 1) & (MicroTaskFifoSlots - 1); + } + break; + case TaskOp::Arith: + task.fifoIdx = control.nextLoadFifoIdx; + control.nextLoadFifoIdx = + (control.nextLoadFifoIdx + 1) & (MicroTaskFifoSlots - 1); + break; + case TaskOp::Release: + task.fifoIdx = 0; + break; + } +} + +bool +MatrixController::queueFull() const +{ + return control.queueSize == DecodedQueueDepth; +} + +bool +MatrixController::queueEmpty() const +{ + return control.queueSize == 0; +} + +void +MatrixController::pushTask(const TaskDesc &task) +{ + const uint8_t tail = (control.queueHead + control.queueSize) & + (DecodedQueueDepth - 1); + 
control.decodedQueue[tail] = task; + ++control.queueSize; +} + +MatrixController::TaskDesc & +MatrixController::headTask() +{ /* Mutable reference to the slot at queueHead; emptiness is not checked here. */ + return control.decodedQueue[control.queueHead]; +} + +const MatrixController::TaskDesc & +MatrixController::headTask() const +{ /* Read-only reference to the slot at queueHead; emptiness is not checked here. */ + return control.decodedQueue[control.queueHead]; +} + +void +MatrixController::popTask() +{ /* Advance the head index with wrap-around and shrink the queue; panics if the queue is empty. */ + panic_if(queueEmpty(), "matrix decoded FIFO pop on empty queue"); + + control.queueHead = (control.queueHead + 1) & (DecodedQueueDepth - 1); + --control.queueSize; +} + +bool +MatrixController::canIssue(const TaskDesc &task) const +{ /* Hazard check: returns true only when the task conflicts with no busy FU, busy register or pending reader. */ + if (task.op == TaskOp::Release) { /* Release is gated only on there being no pending store */ + return !control.pendingStore; + } + + if (fuBusy(task.fu)) { + return false; + } + + if (task.destValid && + (regBusy(task.destIsAcc, task.destReg) || + regHasPendingReaders(task.destIsAcc, task.destReg))) { /* destination must be neither reserved nor still being read */ + return false; + } + + for (size_t i = 0; i < task.srcValid.size(); ++i) { + if (task.srcValid[i] && regBusy(task.srcIsAcc[i], task.srcReg[i])) { /* a source reserved as some destination blocks issue */ + return false; + } + } + + if (task.isStore && task.srcValid[0] && + regHasPendingReaders(task.srcIsAcc[0], task.srcReg[0])) { /* stores additionally require source 0 to have no pending readers */ + return false; + } + + return true; +} + +void +MatrixController::reserveTask(const TaskDesc &task) +{ /* Claim what the task needs: mark its FU busy and record its FIFO index, reserve the destination register, count pending source readers, and flag a pending store. */ + if (task.fu != FuType::None) { + setFuBusy(task.fu, true); + control.fus[fuIndex(task.fu)].fifoIdx = task.fifoIdx; + } + + if (task.destValid) { + reserveReg(task.destIsAcc, task.destReg, task.fu); + } + + for (size_t i = 0; i < task.srcValid.size(); ++i) { + if (task.srcValid[i] && task.srcReadPending[i]) { + addRegReader(task.srcIsAcc[i], task.srcReg[i]); + } + } + + if (task.isStore) { + control.pendingStore = true; + } +} + +void +MatrixController::releaseTask(const TaskDesc &task) +{ /* Inverse of reserveTask: drop pending source readers, release the destination register, clear the pending-store flag for stores, and free the FU. */ + for (size_t i = 0; i < task.srcValid.size(); ++i) { + if (task.srcValid[i] && task.srcReadPending[i]) { + removeRegReader(task.srcIsAcc[i], task.srcReg[i]); + } + } + + if (task.destValid) { + releaseReg(task.destIsAcc, task.destReg, task.fu); + } + + if (task.isStore) { + 
control.pendingStore = false; + } + + if (task.fu != FuType::None) { + setFuBusy(task.fu, false); + } +} + +bool +MatrixController::regBusy(bool is_acc, uint8_t reg_idx) const +{ /* Busy flag of the selected register: cRegs when is_acc, abRegs otherwise; reg_idx is wrapped by normalizeRegIdx. */ + const uint8_t idx = normalizeRegIdx(reg_idx); + return is_acc ? control.cRegs[idx].busy : control.abRegs[idx].busy; +} + +bool +MatrixController::regHasPendingReaders(bool is_acc, uint8_t reg_idx) const +{ /* True when the selected register has a non-zero pendingReaders count. */ + const uint8_t idx = normalizeRegIdx(reg_idx); + return is_acc ? control.cRegs[idx].pendingReaders != 0 : + control.abRegs[idx].pendingReaders != 0; +} + +void +MatrixController::reserveReg(bool is_acc, uint8_t reg_idx, FuType fu) +{ /* Mark the register busy and remember fu as its producer; panics if it is already busy. */ + RegStatus &status = is_acc ? control.cRegs[normalizeRegIdx(reg_idx)] : + control.abRegs[normalizeRegIdx(reg_idx)]; + panic_if(status.busy, "matrix register reserved while busy"); + status.busy = true; + status.producer = fu; +} + +void +MatrixController::releaseReg(bool is_acc, uint8_t reg_idx, FuType fu) +{ /* Clear the busy flag and producer; panics if the register is idle or was reserved by a different FU. */ + RegStatus &status = is_acc ? control.cRegs[normalizeRegIdx(reg_idx)] : + control.abRegs[normalizeRegIdx(reg_idx)]; + panic_if(!status.busy, "matrix register released while idle"); + panic_if(status.producer != fu, "matrix register producer mismatch"); + status.busy = false; + status.producer = FuType::None; +} + +void +MatrixController::addRegReader(bool is_acc, uint8_t reg_idx) +{ /* Increment the register's pending reader count; panics instead of wrapping past UINT8_MAX. */ + RegStatus &status = is_acc ? control.cRegs[normalizeRegIdx(reg_idx)] : + control.abRegs[normalizeRegIdx(reg_idx)]; + panic_if(status.pendingReaders == UINT8_MAX, + "matrix register pending reader count overflow"); + ++status.pendingReaders; +} + +void +MatrixController::removeRegReader(bool is_acc, uint8_t reg_idx) +{ /* Decrement the register's pending reader count; panics instead of wrapping below zero. */ + RegStatus &status = is_acc ? 
control.cRegs[normalizeRegIdx(reg_idx)] : + control.abRegs[normalizeRegIdx(reg_idx)]; + panic_if(status.pendingReaders == 0, + "matrix register pending reader count underflow"); + --status.pendingReaders; +} + +bool +MatrixController::fuBusy(FuType fu) const +{ /* Busy flag of the given FU; FuType::None is never busy. */ + if (fu == FuType::None) { + return false; + } + return control.fus[fuIndex(fu)].busy; +} + +void +MatrixController::setFuBusy(FuType fu, bool busy) +{ /* Set the FU busy flag; ignores FuType::None and panics if the flag already has the requested value. */ + if (fu == FuType::None) { + return; + } + + FuStatus &status = control.fus[fuIndex(fu)]; + panic_if(status.busy == busy, "matrix FU busy state transition invalid"); + status.busy = busy; +} + +void +MatrixController::recordIssue(const TaskDesc &task) +{ /* Update issue-side statistics: accepted/issued counters, the per-operation task counter, the matching TaskEvent counters and, for LoadStore, the memory request accounting. */ + ++stats.tasksAccepted; + ++stats.tasksIssued; + + switch (task.op) { + case TaskOp::Mma: + ++stats.mmaTasks; + ++stats.taskEvents[taskEventIndex(TaskEvent::ComputeIssue)]; + break; + case TaskOp::Arith: /* Arith is counted under zeroTasks and raises the load allocate/issue events */ + ++stats.zeroTasks; + ++stats.taskEvents[taskEventIndex(TaskEvent::LoadAllocate)]; + ++stats.taskEvents[taskEventIndex(TaskEvent::LoadIssue)]; + break; + case TaskOp::Release: + ++stats.releaseTasks; + ++stats.taskEvents[taskEventIndex(TaskEvent::ReleaseIssue)]; + break; + case TaskOp::LoadStore: + switch (task.memPort) { + case MemPort::A: + ++stats.aPortTasks; + break; + case MemPort::B: + ++stats.bPortTasks; + break; + case MemPort::CLoad: + ++stats.cLoadTasks; + break; + case MemPort::CStore: + ++stats.cStoreTasks; + break; + default: /* other ports have no per-port task counter */ + break; + } + if (task.isStore) { + ++stats.taskEvents[taskEventIndex(TaskEvent::StoreIssue)]; + } else { + ++stats.taskEvents[taskEventIndex(TaskEvent::LoadAllocate)]; + ++stats.taskEvents[taskEventIndex(TaskEvent::LoadIssue)]; + } + recordMemoryRequests(task); + break; + } +} + +void +MatrixController::recordCompletion(const TaskDesc &task) +{ /* Update completion-side TaskEvent counters for the task's operation; Release raises none. */ + switch (task.op) { + case TaskOp::Mma: + stats.taskEvents[taskEventIndex(TaskEvent::ComputeReadFinish)] += 2; /* two read-finish events per MMA (presumably one per source operand - confirm) */ + ++stats.taskEvents[taskEventIndex(TaskEvent::ComputeWriteFinish)]; + break; + case TaskOp::Arith: + 
++stats.taskEvents[taskEventIndex(TaskEvent::LoadFinish)]; + break; + case TaskOp::LoadStore: + if (task.isStore) { + ++stats.taskEvents[taskEventIndex(TaskEvent::StoreFinish)]; + } else { + ++stats.taskEvents[taskEventIndex(TaskEvent::LoadFinish)]; + } + break; + case TaskOp::Release: + break; + } +} + +void +MatrixController::recordMemoryRequests(const TaskDesc &task) +{ /* Account the task's memory traffic in the aggregate and read/write statistics, then pass every request through LocalMMU port arbitration and source allocation, completing them in batches of at most LocalMmuSourceCount. */ + if (task.memoryRequests == 0 || task.memPort == MemPort::Num) { /* nothing to account without requests or a real port */ + return; + } + + stats.memoryRequests += task.memoryRequests; + stats.memoryBytes += task.memoryBytes; + stats.memoryBusBytes += task.memoryBusBytes; + + if (task.isStore) { + stats.memoryWriteRequests += task.memoryRequests; + stats.memoryWriteBytes += task.memoryBytes; + stats.memoryWriteBusBytes += task.memoryBusBytes; + } else { + stats.memoryReadRequests += task.memoryRequests; + stats.memoryReadBytes += task.memoryBytes; + stats.memoryReadBusBytes += task.memoryBusBytes; + } + + std::array inflight_sources = {}; /* NOTE(review): the template arguments are missing in this view; the loop below needs room for LocalMmuSourceCount entries */ + uint32_t inflight_count = 0; + + for (uint64_t i = 0; i < task.memoryRequests; ++i) { + const MemPort selected_port = arbitrateLocalMmuPort(task.memPort); + ++stats.memPortRequests[memPortIndex(selected_port)]; + inflight_sources[inflight_count++] = + allocateLocalMmuSource(selected_port); + + if (inflight_count == LocalMmuSourceCount) { /* every source id is in use: complete the whole batch before allocating more */ + for (uint32_t j = 0; j < inflight_count; ++j) { + completeLocalMmuRequest(inflight_sources[j], task.isStore); + } + inflight_count = 0; + } + } + + for (uint32_t i = 0; i < inflight_count; ++i) { /* complete the final partial batch */ + completeLocalMmuRequest(inflight_sources[i], task.isStore); + } +} + +MatrixController::MemPort +MatrixController::arbitrateLocalMmuPort(MemPort requested) +{ /* Scan port indices starting at firstMmuRequestIndex, wrapping modulo MemPortCount; on reaching the requested port, move the start index to the port after it, bump the arbitration statistics and return requested. Panics if the scan never reaches it. */ + const uint8_t requested_idx = static_cast(memPortIndex(requested)); + + for (uint8_t offset = 0; offset < MemPortCount; ++offset) { + const uint8_t candidate = static_cast( + (control.firstMmuRequestIndex + offset) % MemPortCount); + if (candidate == requested_idx) { + control.firstMmuRequestIndex = + static_cast((candidate + 
1) % MemPortCount); + ++stats.localMmuArbitrations; + ++stats.localMmuPortSelections[candidate]; + return requested; + } + } + + panic("matrix LocalMMU arbitration missed requested port"); +} + +void +MatrixController::completeLocalMmuRequest(uint8_t source, bool is_store) +{ /* Record the response for a busy source id (write-ack counters for stores, read-data counters otherwise) against the port it was routed to, then release the source. Panics on an out-of-range, idle or unrouted source. */ + panic_if(source >= LocalMmuSourceCount, + "matrix LocalMMU response source out of range"); + panic_if(!control.sourceBusy[source], + "matrix LocalMMU response for idle source"); + + const MemPort port = control.sourceToPort[source]; + panic_if(port == MemPort::Num, + "matrix LocalMMU response has no routed port"); + + const size_t port_idx = memPortIndex(port); + if (is_store) { + ++stats.localMmuWriteAcks; + ++stats.memPortWriteAcks[port_idx]; + ++stats.taskEvents[taskEventIndex(TaskEvent::MemoryWriteAckResponse)]; + } else { + ++stats.localMmuReadDataResponses; + ++stats.memPortReadDataResponses[port_idx]; + ++stats.taskEvents[taskEventIndex(TaskEvent::MemoryReadDataResponse)]; + } + + releaseLocalMmuSource(source); +} + +uint8_t +MatrixController::allocateLocalMmuSource(MemPort port) +{ /* Claim a free source id for a request on port, update the outstanding count and its high-water mark, and return the id; panics when no id is free. */ + const uint8_t source = selectHighestFreeSource(control.sourceBusy); /* helper is defined outside this view; a result >= LocalMmuSourceCount is treated as no free id */ + panic_if(source >= LocalMmuSourceCount, + "matrix LocalMMU has no free source id"); + + control.sourceBusy[source] = true; + control.sourceToPort[source] = port; + ++control.localMmuOutstanding; + ++stats.localMmuSourceAllocations; + stats.localMmuMaxOutstanding = std::max( + static_cast(stats.localMmuMaxOutstanding.value()), + control.localMmuOutstanding); + return source; +} + +void +MatrixController::releaseLocalMmuSource(uint8_t source) +{ /* Return a busy source id to the free pool, resetting its port mapping to MemPort::Num; panics on an out-of-range or idle source or on outstanding-count underflow. */ + panic_if(source >= LocalMmuSourceCount, + "matrix LocalMMU release source out of range"); + panic_if(!control.sourceBusy[source], + "matrix LocalMMU release for idle source"); + panic_if(control.localMmuOutstanding == 0, + "matrix LocalMMU outstanding count underflow"); + + control.sourceBusy[source] = false; + control.sourceToPort[source] = MemPort::Num; + --control.localMmuOutstanding; + 
++stats.localMmuSourceReleases; +} + +uint8_t +MatrixController::normalizeRegIdx(uint64_t reg_idx) +{ /* Wrap reg_idx into the register range by masking with MatrixRegCount - 1 (relies on MatrixRegCount being a power of two). */ + return static_cast(reg_idx & (MatrixRegCount - 1)); +} + +uint32_t +MatrixController::clampTileM(uint64_t value) +{ /* Clamp value to at most MatrixMaxM. */ + return value > MatrixMaxM ? MatrixMaxM : static_cast(value); +} + +uint32_t +MatrixController::clampTileK(uint64_t value) +{ /* Clamp value to at most MatrixMaxK. */ + return value > MatrixMaxK ? MatrixMaxK : static_cast(value); +} + +uint32_t +MatrixController::clampTileN(uint64_t value) +{ /* Clamp value to at most MatrixMaxN. */ + return value > MatrixMaxN ? MatrixMaxN : static_cast(value); +} + +uint64_t +MatrixController::divCeil(uint64_t numerator, uint64_t denominator) +{ /* Ceiling division; returns 0 for a zero numerator. NOTE(review): a zero denominator is not guarded, and numerator + denominator - 1 can wrap for operands near the uint64_t maximum. */ + return numerator == 0 ? 0 : (numerator + denominator - 1) / denominator; +} + +MatrixController::MemoryRequestShape +MatrixController::matrixMemoryRequestShape( + MemPort port, bool transpose, uint32_t rows, uint32_t cols, + ElemWidth width) +{ /* Returns the {row count, bytes per row} of a tile transfer on port; a default (all-zero) shape when rows or cols is 0. */ + const uint32_t elem_bytes = elemWidthBytes(width); + if (rows == 0 || cols == 0) { + return {}; + } + + if (port == MemPort::A || port == MemPort::B || port == MemPort::CLoad) { /* these ports transfer rows x (cols * element size) and ignore transpose */ + return {rows, static_cast(cols) * elem_bytes}; + } + + if (port == MemPort::CStore) { /* CStore: transpose swaps the two dimensions, and the row count is rounded up to a multiple of MatrixMN */ + const uint64_t major_dim = transpose ? cols : rows; + const uint64_t reduce_dim = transpose ? 
rows : cols; + const uint64_t rounded_major = + divCeil(major_dim, MatrixMN) * MatrixMN; + return {rounded_major, reduce_dim * elem_bytes}; + } + + return {rows, static_cast(cols) * elem_bytes}; /* any other port value falls back to the same shape as A/B/CLoad */ +} + +uint64_t +MatrixController::matrixMemoryRequests(MemPort port, bool transpose, + Addr base, Addr stride, uint32_t rows, uint32_t cols, ElemWidth width) +{ /* Number of memory requests for a tile transfer: derive the transfer shape for the port, then count the requests for its rows. */ + const MemoryRequestShape shape = + matrixMemoryRequestShape(port, transpose, rows, cols, width); + return rowRequests(base, stride, shape.rows, shape.bytesPerRow); +} + +uint64_t +MatrixController::rowRequests( + Addr base, Addr stride, uint64_t rows, uint64_t bytes_per_row) +{ /* Sum over all rows of the number of OutsideDataWidthBytes-aligned chunks each row touches; returns 0 when rows or bytes_per_row is 0. */ + if (rows == 0 || bytes_per_row == 0) { + return 0; + } + + uint64_t requests = 0; + for (uint64_t row = 0; row < rows; ++row) { + const Addr row_base = base + static_cast(row) * stride; + const uint64_t offset = row_base & (OutsideDataWidthBytes - 1); /* misalignment of the row start within its chunk */ + requests += divCeil(offset + bytes_per_row, OutsideDataWidthBytes); /* a misaligned start can add one extra chunk */ + } + return requests; +} + +uint32_t +MatrixController::elemWidthBytes(ElemWidth width) +{ /* Element size in bytes for each ElemWidth enumerator; panics on a value outside the enum. */ + switch (width) { + case ElemWidth::E8: + return 1; + case ElemWidth::E16: + return 2; + case ElemWidth::E32: + return 4; + case ElemWidth::E64: + return 8; + } + + panic("unsupported matrix element width"); +} + +size_t +MatrixController::fuIndex(FuType fu) +{ /* Array index for an FuType: its underlying enumerator value. */ + return static_cast(fu); +} + +size_t +MatrixController::memPortIndex(MemPort port) +{ /* Array index for a MemPort: its underlying enumerator value. */ + return static_cast(port); +} + +size_t +MatrixController::taskEventIndex(TaskEvent event) +{ /* Array index for a TaskEvent: its underlying enumerator value. */ + return static_cast(event); +} + +} // namespace matrix +} // namespace gem5 diff --git a/src/matrix/matrix_controller.hh b/src/matrix/matrix_controller.hh new file mode 100644 index 0000000000..d78c86551c --- /dev/null +++ b/src/matrix/matrix_controller.hh @@ -0,0 +1,521 @@ +/* + * XSAI CUTE-aligned matrix controller scaffold. 
+ */ + +#ifndef __MATRIX_MATRIX_CONTROLLER_HH__ +#define __MATRIX_MATRIX_CONTROLLER_HH__ + +#include +#include +#include +#include + +#include "base/statistics.hh" +#include "base/types.hh" +#include "sim/faults.hh" +#include "sim/serialize.hh" + +namespace gem5 +{ + +class ExecContext; + +namespace matrix +{ + +class MatrixController +{ /* Matrix unit model combining data, control and timing state with statistics (see the DataState, ControlState, TimingState and Stats members). */ + public: + static constexpr uint32_t TokenCount = 32; + static constexpr uint32_t MatrixRegCount = 4; /* used as a mask (MatrixRegCount - 1) by normalizeRegIdx, so it must stay a power of two; NOTE(review): unlike the other mask constants, no static_assert enforces this */ + static constexpr uint32_t DecodedQueueDepth = 8; + static constexpr uint32_t MicroTaskFifoSlots = 4; + static constexpr uint32_t LocalMmuSourceCount = 64; + static constexpr uint32_t OutsideDataWidthBits = 512; + static constexpr uint32_t OutsideDataWidthBytes = + OutsideDataWidthBits / 8; + static constexpr uint32_t ReduceWidthBytes = 64; + static constexpr uint32_t ResultWidthBytes = 4; + static constexpr uint32_t MatrixMN = 4; /* CStore row counts are rounded up to a multiple of this in matrixMemoryRequestShape */ + + static constexpr uint32_t MatrixMaxM = 128; /* upper bounds applied by clampTileM/K/N */ + static constexpr uint32_t MatrixMaxK = 64; + static constexpr uint32_t MatrixMaxN = 128; + static constexpr uint32_t MatrixABRegBytes = MatrixMaxM * MatrixMaxK; + static constexpr uint32_t MatrixAccElems = MatrixMaxM * MatrixMaxN; + + static constexpr uint8_t DefaultAReg = 0; + static constexpr uint8_t DefaultBReg = 1; + static constexpr uint8_t DefaultAccReg = 0; + + enum class TaskOp : uint8_t + { /* Operation class of a task; selects the FIFO counter and the statistics path. */ + Mma, + LoadStore, + Release, + Arith, + }; + + enum class FuType : uint8_t + { /* Functional units; None means no FU is used, Num is the enumerator count (see FuTypeCount). */ + None = 0, + AML, + BML, + CMLLoad, + CMLStore, + Compute, + Num + }; + + enum class MemPort : uint8_t + { /* LocalMMU ports; Num is the enumerator count and doubles as the no-port marker in sourceToPort. */ + A = 0, + B, + CLoad, + CStore, + BScale, + AScale, + Num + }; + + static_assert(static_cast(MemPort::A) == 0 && + static_cast(MemPort::B) == 1 && + static_cast(MemPort::CLoad) == 2 && + static_cast(MemPort::CStore) == 3 && + static_cast(MemPort::BScale) == 4 && + static_cast(MemPort::AScale) == 5, + "matrix MemPort order must match RTL LocalMMUTaskType encoding"); + + enum class TaskEvent : uint8_t + { /* Indices into Stats::taskEvents; Num is the enumerator count. */ + LoadAllocate = 0, + LoadIssue, + LoadFinish, + ComputeIssue, + 
ComputeReadFinish, + ComputeWriteFinish, + StoreIssue, + StoreFinish, + ReleaseIssue, + MemoryReadDataResponse, + MemoryWriteAckResponse, + Num + }; + + enum class ElemWidth : uint8_t + { /* Each enumerator's value equals the element size in bytes that elemWidthBytes returns for it. */ + E8 = 1, + E16 = 2, + E32 = 4, + E64 = 8, + }; + + static constexpr size_t FuTypeCount = static_cast(FuType::Num); /* array sizes derived from the Num sentinels */ + static constexpr size_t MemPortCount = static_cast(MemPort::Num); + static constexpr size_t TaskEventCount = + static_cast(TaskEvent::Num); + + static_assert((DecodedQueueDepth & (DecodedQueueDepth - 1)) == 0, + "matrix decoded FIFO depth must be a power of two"); /* the implementation wraps indices with (value - 1) masks, hence these checks */ + static_assert((MicroTaskFifoSlots & (MicroTaskFifoSlots - 1)) == 0, + "matrix micro-task FIFO slots must be a power of two"); + static_assert((LocalMmuSourceCount & (LocalMmuSourceCount - 1)) == 0, + "matrix LocalMMU source count must be a power of two"); + static_assert((OutsideDataWidthBytes & (OutsideDataWidthBytes - 1)) == 0, + "matrix outside data width must be a power of two"); + static_assert((MatrixMN & (MatrixMN - 1)) == 0, + "matrix Matrix_MN must be a power of two"); + + struct Stats : public statistics::Group + { /* Counters updated by the record*, LocalMMU and token helpers. */ + statistics::Scalar tasksAccepted; /* tasksAccepted and tasksIssued are both incremented in recordIssue */ + statistics::Scalar tasksIssued; + statistics::Scalar tasksCompleted; + statistics::Scalar tasksAborted; + + statistics::Scalar aPortTasks; + statistics::Scalar bPortTasks; + statistics::Scalar cLoadTasks; + statistics::Scalar cStoreTasks; + statistics::Scalar mmaTasks; + statistics::Scalar zeroTasks; /* incremented for TaskOp::Arith in recordIssue */ + statistics::Scalar releaseTasks; + + statistics::Scalar memoryRequests; /* totals; the Read and Write variants below split them by task.isStore */ + statistics::Scalar memoryReadRequests; + statistics::Scalar memoryWriteRequests; + statistics::Scalar memoryBytes; + statistics::Scalar memoryReadBytes; + statistics::Scalar memoryWriteBytes; + statistics::Scalar memoryBusBytes; + statistics::Scalar memoryReadBusBytes; + statistics::Scalar memoryWriteBusBytes; + statistics::Scalar localMmuSourceAllocations; + statistics::Scalar localMmuSourceReleases; + statistics::Scalar localMmuArbitrations; + statistics::Scalar 
localMmuReadDataResponses; + statistics::Scalar localMmuWriteAcks; + statistics::Scalar localMmuMaxOutstanding; /* high-water mark maintained in allocateLocalMmuSource */ + statistics::Scalar memoryPipelineRequests; + statistics::Scalar memoryPipelineReadResponses; + statistics::Scalar memoryPipelineWriteAcks; + statistics::Scalar memoryPipelineSourceStallTicks; + statistics::Scalar memoryPipelineRequestQueueTicks; + statistics::Scalar memoryPipelineResponseQueueTicks; + statistics::Scalar memoryPipelineLastRequestTick; + statistics::Scalar memoryPipelineLastResponseTick; + statistics::Scalar memoryPipelineMaxOutstanding; + statistics::Scalar timingTasks; + statistics::Scalar timingQueueTicks; + statistics::Scalar timingMaxQueueTicks; + statistics::Scalar timingBusyTicks; + statistics::Scalar timingLastIssueTick; + statistics::Scalar timingLastCompletionTick; + statistics::Scalar acquireStallEvents; + statistics::Scalar acquireStallTicks; + statistics::Scalar tokenReleaseEvents; /* one per enqueueTokenRelease call */ + statistics::Scalar tokenReleaseDelayEvents; /* delay counters are updated by delayPendingTokenEvents */ + statistics::Scalar tokenReleaseDelayTicks; + std::array memPortRequests = {}; /* the per-port arrays below are indexed with memPortIndex; taskEvents with taskEventIndex */ + std::array memPortReadDataResponses = {}; + std::array memPortWriteAcks = {}; + std::array localMmuPortSelections = {}; + std::array taskEvents = {}; + + Stats(statistics::Group *parent); + void resetStats() override; + }; + + struct TimingConfig + { /* Timing knobs with their in-class defaults; the field names carry the unit (cycles, or requests per cycle). */ + uint32_t issueIntervalCycles = 1; + uint32_t loadBaseCycles = 4; + uint32_t storeBaseCycles = 4; + uint32_t zeroCycles = 1; + uint32_t computeBaseCycles = 2; + uint32_t computeReadCycles = 1; + uint32_t releaseCycles = 1; + uint32_t localMmuIssuePerCycle = 1; + uint32_t localMmuArbCycles = 1; + uint32_t l2RequestPipelineCycles = 1; + uint32_t l2ResponsePipelineCycles = 1; + uint32_t localMmuReadLatencyCycles = 20; + uint32_t localMmuWriteAckLatencyCycles = 12; + }; + + struct ControlSnapshot + { /* Plain-value view of controller state returned by controlSnapshot(). */ + uint8_t decodedQueueHead = 0; + uint8_t decodedQueueSize = 0; + uint8_t nextLoadFifoIdx = 0; + uint8_t nextComputeFifoIdx = 0; + uint8_t nextStoreFifoIdx = 0; + + uint8_t 
fuBusyMask = 0; /* presumably one bit per FU or register in each mask below - confirm against controlSnapshot() */ + uint8_t abBusyMask = 0; + uint8_t cBusyMask = 0; + uint8_t abPendingReaderMask = 0; + uint8_t cPendingReaderMask = 0; + std::array abPendingReaders = {}; + std::array cPendingReaders = {}; + + bool pendingStore = false; + uint8_t firstMmuRequestIndex = 0; + uint8_t nextLocalMmuSource = 0; + uint8_t localMmuOutstanding = 0; + uint64_t localMmuBusySourceMask = 0; /* 64 bits wide, matching LocalMmuSourceCount = 64 */ + + Tick timingNextIssueTick = 0; + Tick timingPendingStoreReadyTick = 0; + Tick timingLastIssueTick = 0; + Tick timingLastCompletionTick = 0; + Tick timingLocalMmuIssueSlotTick = 0; + Tick timingLocalMmuResponseSlotTick = 0; + Tick timingLocalMmuLastRequestTick = 0; + Tick timingLocalMmuLastResponseTick = 0; + uint8_t timingLocalMmuOutstanding = 0; + uint16_t pendingTokenEvents = 0; /* same width as the saturating pendingTokenEvents() helper */ + }; + + MatrixController(statistics::Group *parent); + + void reset(); + void setTimingConfig(const TimingConfig &config); + void retireReadyTokensUpTo(Tick now); + bool tokenTargetReached(uint64_t token_idx, uint64_t target) const; + Tick tokenReadyTick(uint64_t token_idx, uint64_t target) const; + void recordAcquireStall(Tick ticks); + + void serialize(CheckpointOut &cp) const; + void unserialize(CheckpointIn &cp); + + void syncReset(uint64_t token_idx); /* token operations; token_idx selects one of TokenCount tokens */ + void release(ExecContext *xc, uint64_t token_idx); + void acquire(ExecContext *xc, uint64_t token_idx, uint64_t target); + + void setTileM(uint64_t value); /* tile dimensions (presumably clamped through clampTileM/K/N - confirm in the setters) */ + void setTileK(uint64_t value); + void setTileN(uint64_t value); + uint32_t getTileM() const { return data.tileM; } + uint32_t getTileK() const { return data.tileK; } + uint32_t getTileN() const { return data.tileN; } + + Fault load(ExecContext *xc, MemPort port, Addr base, Addr stride, + uint32_t rows, uint32_t cols, ElemWidth width, bool transpose, + uint64_t reg_idx); + Fault store(ExecContext *xc, MemPort port, Addr base, Addr stride, + uint32_t rows, uint32_t cols, ElemWidth width, bool transpose, + uint64_t reg_idx); + Fault loadA8(ExecContext *xc, Addr base, Addr stride, + uint64_t reg_idx = 
DefaultAReg); + Fault loadB8(ExecContext *xc, Addr base, Addr stride, + uint64_t reg_idx = DefaultBReg); + Fault loadC32(ExecContext *xc, Addr base, Addr stride, + uint64_t acc_idx = DefaultAccReg); + Fault storeC32(ExecContext *xc, Addr base, Addr stride, + uint64_t acc_idx = DefaultAccReg); + void zeroAcc(ExecContext *xc, uint64_t acc_idx = DefaultAccReg); + void zero(ExecContext *xc, uint64_t reg_idx, bool is_acc); + void mmaccWB(uint64_t src_a_idx = DefaultAReg, + uint64_t src_b_idx = DefaultBReg, + uint64_t dst_acc_idx = DefaultAccReg); /* NOTE(review): with all defaults, a call such as mmaccWB(nullptr) must pick between this overload and the ExecContext* one below - confirm the resolution is the intended one */ + void mmaccWB(ExecContext *xc, uint64_t src_a_idx = DefaultAReg, + uint64_t src_b_idx = DefaultBReg, + uint64_t dst_acc_idx = DefaultAccReg); + void mmacc(ExecContext *xc, uint64_t src_a_idx, uint64_t src_b_idx, + uint64_t dst_acc_idx, uint32_t rows, uint32_t cols, uint32_t depth); + + const Stats &getStats() const { return stats; } + ControlSnapshot controlSnapshot() const; + +#ifdef UNIT_TEST /* test-only accessors, compiled only when UNIT_TEST is defined */ + uint8_t allocateLocalMmuSourceForTest(MemPort port); + void releaseLocalMmuSourceForTest(uint8_t source); + void writeABRegForTest( + uint64_t reg_idx, uint32_t row, uint32_t col, int8_t value); + int32_t readAccRegForTest( + uint64_t reg_idx, uint32_t row, uint32_t col) const; + RegVal readTokenForTest(uint64_t token_idx) const; + Tick tokenReadyTickForTest(uint64_t token_idx, uint64_t target) const; + void scheduleMemoryTimingForTest( + MemPort port, bool is_store, uint64_t requests); + void scheduleMemoryTimingForTest(MemPort port, bool is_store, + Addr base, Addr stride, uint32_t rows, uint32_t cols, + ElemWidth elem_width, bool transpose); +#endif + + private: + struct DataState + { /* Architectural data: tile sizes, A/B and accumulator register contents, and token values. */ + uint32_t tileM = 0; + uint32_t tileK = 0; + uint32_t tileN = 0; + std::array, MatrixRegCount> abRegs; + std::array, MatrixRegCount> accRegs; + std::vector tokens; + }; + + struct RegStatus + { /* Scoreboard entry for one register. */ + bool busy = false; /* set by reserveReg, cleared by releaseReg */ + uint8_t pendingReaders = 0; /* maintained by addRegReader and removeRegReader */ + FuType producer = FuType::None; /* FU that reserved the register; checked on release */ + }; + + struct FuStatus + { /* Per-FU state: busy flag and the FIFO index of the task occupying it (written in reserveTask). */ + bool busy = false; + uint8_t fifoIdx = 0; 
+ }; + + struct TaskDesc + { /* Decoded task: operation, FU and port, register operands, and the memory footprint used for statistics. */ + TaskOp op = TaskOp::Arith; + FuType fu = FuType::None; + MemPort memPort = MemPort::Num; + + bool destValid = false; + bool destIsAcc = false; /* selects the accumulator (cRegs) scoreboard instead of abRegs */ + uint8_t destReg = 0; + + std::array srcValid = {}; /* parallel per-source arrays, indexed together in canIssue, reserveTask and releaseTask */ + std::array srcIsAcc = {}; + std::array srcReg = {}; + std::array srcReadPending = {}; /* when set, the source is counted as a pending reader while the task is in flight */ + + bool isStore = false; + bool coherent = true; + bool transpose = false; + uint64_t tokenIdx = 0; + uint8_t fifoIdx = 0; /* assigned by assignFifoIdx */ + uint8_t needMask = 0; + + Addr base = 0; + Addr stride = 0; + uint32_t rows = 0; + uint32_t cols = 0; + uint32_t depth = 0; + uint32_t bytesPerRow = 0; + ElemWidth elemWidth = ElemWidth::E8; + + uint64_t memoryBytes = 0; /* the three memory* fields feed recordMemoryRequests */ + uint64_t memoryRequests = 0; + uint64_t memoryBusBytes = 0; + }; + + struct MemoryRequestShape + { /* Result of matrixMemoryRequestShape: number of rows and bytes transferred per row. */ + uint64_t rows = 0; + uint64_t bytesPerRow = 0; + }; + + struct ControlState + { /* Scoreboard and queue state of the controller. */ + std::array decodedQueue = {}; /* ring buffer addressed through queueHead and queueSize (see pushTask and popTask) */ + uint8_t queueHead = 0; + uint8_t queueSize = 0; + uint8_t nextLoadFifoIdx = 0; + uint8_t nextComputeFifoIdx = 0; + uint8_t nextStoreFifoIdx = 0; + + std::array abRegs = {}; + std::array cRegs = {}; + std::array fus = {}; + + bool pendingStore = false; /* set while a store task is reserved; blocks Release tasks in canIssue */ + + std::array sourceToPort = {}; /* NOTE(review): value-initialised entries are MemPort::A (0) rather than MemPort::Num - confirm resetControlState sets them to Num */ + std::array sourceBusy = {}; + uint8_t firstMmuRequestIndex = 0; /* round-robin start position used by arbitrateLocalMmuPort */ + uint8_t localMmuOutstanding = 0; + }; + + struct TimingState + { /* Tick bookkeeping for the timing model. */ + Tick nextIssueTick = 0; + Tick pendingStoreReadyTick = 0; + Tick lastIssueTick = 0; + Tick lastCompletionTick = 0; + std::array fuReadyTick = {}; + std::array abWriteReadyTick = {}; + std::array cWriteReadyTick = {}; + std::array abReadBlockTick = {}; + std::array cReadBlockTick = {}; + std::array, TokenCount> tokenReadyTicks = {}; /* per-token list of pending release ticks, kept sorted by enqueueTokenRelease */ + Tick localMmuIssueSlotTick = 0; + uint32_t localMmuIssueSlotsUsed = 0; + Tick localMmuResponseSlotTick = 0; + Tick localMmuLastRequestTick = 0; + Tick localMmuLastResponseTick = 0; + std::array localMmuSourceReadyTick = {}; + }; + + DataState data; + ControlState control; + TimingConfig timingConfig; + TimingState timing; + Stats stats; + + void 
resetDataState(); + void resetControlState(); + void resetTimingState(); + void normalizeDataState(); + + RegVal &token(uint64_t idx); /* element accessors for tokens and register storage */ + const RegVal &token(uint64_t idx) const; + std::vector &abReg(uint64_t idx); + const std::vector &abReg(uint64_t idx) const; + std::vector &accReg(uint64_t idx); + const std::vector &accReg(uint64_t idx) const; + + TaskDesc makeMemoryTask(MemPort port, FuType fu, bool is_store, + bool is_acc, Addr base, Addr stride, uint32_t rows, uint32_t cols, + ElemWidth elem_width, bool transpose, uint8_t reg_idx) const; /* TaskDesc builders, one per task kind */ + TaskDesc makeMmaTask(uint8_t src_a_idx, uint8_t src_b_idx, + uint8_t dst_acc_idx, uint32_t rows, uint32_t cols, + uint32_t depth) const; + TaskDesc makeZeroTask(uint8_t reg_idx, bool is_acc) const; + TaskDesc makeReleaseTask(uint64_t token_idx) const; + + Fault executeLoad(ExecContext *xc, const TaskDesc &task); /* functional execution of a task */ + Fault executeStore(ExecContext *xc, const TaskDesc &task); + void executeZero(const TaskDesc &task); + void executeMma(const TaskDesc &task); + + Tick scheduleTimingTask(ExecContext *xc, const TaskDesc &task); /* timing-model helpers */ + Tick scheduleMemoryPipeline( + ExecContext *xc, const TaskDesc &task, Tick issue_tick); + Tick taskLatencyTicks(ExecContext *xc, const TaskDesc &task) const; + Tick taskSourceReadDoneTick( + ExecContext *xc, const TaskDesc &task, Tick issue_tick, + Tick complete_tick, size_t src_idx) const; + Tick cpuCycleTicks(ExecContext *xc) const; + Tick cyclesToTicks(ExecContext *xc, uint64_t cycles) const; + Tick peekLocalMmuIssueSlot(ExecContext *xc, Tick earliest_tick) const; + Tick reserveLocalMmuIssueSlot(ExecContext *xc, Tick earliest_tick); + Tick reserveLocalMmuResponseSlot(ExecContext *xc, Tick earliest_tick); + uint8_t chooseTimingLocalMmuSource(Tick earliest_tick) const; + uint32_t timingLocalMmuOutstanding(Tick tick) const; + void retireReadyTokens(Tick now); /* token release queue helpers */ + void enqueueTokenRelease(uint64_t token_idx, Tick ready_tick); + void delayPendingTokenEvents(Tick ready_tick); + Tick tokenTargetReadyTick(uint64_t 
token_idx, uint64_t target) const; + uint16_t pendingTokenEvents() const; + + void beginTask(const TaskDesc &task); /* task lifecycle: enqueue and issue, then complete or abort */ + void completeTask(); + void abortTask(); + void assignFifoIdx(TaskDesc &task); + + bool queueFull() const; /* decoded-queue ring buffer primitives */ + bool queueEmpty() const; + void pushTask(const TaskDesc &task); + TaskDesc &headTask(); + const TaskDesc &headTask() const; + void popTask(); + + bool canIssue(const TaskDesc &task) const; /* hazard check plus resource reserve and release */ + void reserveTask(const TaskDesc &task); + void releaseTask(const TaskDesc &task); + + bool regBusy(bool is_acc, uint8_t reg_idx) const; /* register scoreboard; is_acc selects cRegs over abRegs */ + bool regHasPendingReaders(bool is_acc, uint8_t reg_idx) const; + void reserveReg(bool is_acc, uint8_t reg_idx, FuType fu); + void releaseReg(bool is_acc, uint8_t reg_idx, FuType fu); + void addRegReader(bool is_acc, uint8_t reg_idx); + void removeRegReader(bool is_acc, uint8_t reg_idx); + + bool fuBusy(FuType fu) const; + void setFuBusy(FuType fu, bool busy); + + void recordIssue(const TaskDesc &task); /* statistics recording and LocalMMU source bookkeeping */ + void recordCompletion(const TaskDesc &task); + void recordMemoryRequests(const TaskDesc &task); + MemPort arbitrateLocalMmuPort(MemPort requested); + uint8_t allocateLocalMmuSource(MemPort port); + void completeLocalMmuRequest(uint8_t source, bool is_store); + void releaseLocalMmuSource(uint8_t source); + + static uint8_t normalizeRegIdx(uint64_t reg_idx); /* stateless helpers */ + static uint32_t clampTileM(uint64_t value); + static uint32_t clampTileK(uint64_t value); + static uint32_t clampTileN(uint64_t value); + static uint64_t divCeil(uint64_t numerator, uint64_t denominator); + static MemoryRequestShape matrixMemoryRequestShape( + MemPort port, bool transpose, uint32_t rows, uint32_t cols, + ElemWidth width); + static uint64_t matrixMemoryRequests(MemPort port, bool transpose, + Addr base, Addr stride, uint32_t rows, uint32_t cols, + ElemWidth width); + static uint64_t rowRequests( + Addr base, Addr stride, uint64_t rows, uint64_t bytes_per_row); + static uint32_t elemWidthBytes(ElemWidth width); + static size_t fuIndex(FuType fu); + 
static size_t memPortIndex(MemPort port); + static size_t taskEventIndex(TaskEvent event); +}; + +} // namespace matrix +} // namespace gem5 + +#endif // __MATRIX_MATRIX_CONTROLLER_HH__ diff --git a/src/matrix/matrix_controller.test.cc b/src/matrix/matrix_controller.test.cc new file mode 100644 index 0000000000..0a58d030b5 --- /dev/null +++ b/src/matrix/matrix_controller.test.cc @@ -0,0 +1,385 @@ +#include + +#include "base/gtest/cur_tick_fake.hh" +#include "base/stats/group.hh" +#include "matrix/matrix_controller.hh" + +namespace gem5 +{ +namespace matrix +{ +namespace +{ + +GTestTickHandler tickHandler; + +MatrixController::TimingConfig +testTimingConfig() +{ + MatrixController::TimingConfig config; + config.issueIntervalCycles = 2; + config.zeroCycles = 3; + config.computeBaseCycles = 5; + config.computeReadCycles = 1; + config.releaseCycles = 7; + return config; +} + +statistics::Counter +value(const statistics::Scalar &stat) +{ + return stat.value(); +} + +} // namespace + +TEST(MatrixControllerTest, MmaUpdatesDataAndSchedulesFixedTiming) +{ + tickHandler.setCurTick(0); + statistics::Group root(nullptr); + MatrixController controller(&root); + controller.setTimingConfig(testTimingConfig()); + controller.setTileM(2); + controller.setTileN(2); + controller.setTileK(3); + + controller.zeroAcc(nullptr); + controller.writeABRegForTest(MatrixController::DefaultAReg, 0, 0, 1); + controller.writeABRegForTest(MatrixController::DefaultAReg, 0, 1, 2); + controller.writeABRegForTest(MatrixController::DefaultAReg, 0, 2, 3); + controller.writeABRegForTest(MatrixController::DefaultAReg, 1, 0, 4); + controller.writeABRegForTest(MatrixController::DefaultAReg, 1, 1, 5); + controller.writeABRegForTest(MatrixController::DefaultAReg, 1, 2, 6); + controller.writeABRegForTest(MatrixController::DefaultBReg, 0, 0, 7); + controller.writeABRegForTest(MatrixController::DefaultBReg, 0, 1, 8); + controller.writeABRegForTest(MatrixController::DefaultBReg, 0, 2, 9); + 
controller.writeABRegForTest(MatrixController::DefaultBReg, 1, 0, 10); + controller.writeABRegForTest(MatrixController::DefaultBReg, 1, 1, 11); + controller.writeABRegForTest(MatrixController::DefaultBReg, 1, 2, 12); + + controller.mmaccWB(nullptr); + + EXPECT_EQ(controller.readAccRegForTest(MatrixController::DefaultAccReg, 0, + 0), + 50); + EXPECT_EQ(controller.readAccRegForTest(MatrixController::DefaultAccReg, 0, + 1), + 68); + EXPECT_EQ(controller.readAccRegForTest(MatrixController::DefaultAccReg, 1, + 0), + 122); + EXPECT_EQ(controller.readAccRegForTest(MatrixController::DefaultAccReg, 1, + 1), + 167); + + const auto snapshot = controller.controlSnapshot(); + EXPECT_EQ(snapshot.decodedQueueSize, 0); + EXPECT_EQ(snapshot.fuBusyMask, 0); + EXPECT_EQ(snapshot.abPendingReaderMask, 0); + EXPECT_EQ(snapshot.cPendingReaderMask, 0); + EXPECT_EQ(snapshot.timingLastIssueTick, 3); + EXPECT_EQ(snapshot.timingLastCompletionTick, 8); + EXPECT_EQ(snapshot.timingNextIssueTick, 5); + + const auto &stats = controller.getStats(); + EXPECT_EQ(value(stats.tasksAccepted), 2); + EXPECT_EQ(value(stats.tasksIssued), 2); + EXPECT_EQ(value(stats.tasksCompleted), 2); + EXPECT_EQ(value(stats.zeroTasks), 1); + EXPECT_EQ(value(stats.mmaTasks), 1); + EXPECT_EQ(value(stats.timingTasks), 2); + EXPECT_EQ(value(stats.timingQueueTicks), 3); + EXPECT_EQ(value(stats.timingMaxQueueTicks), 3); + EXPECT_EQ(value(stats.timingBusyTicks), 8); + EXPECT_EQ(value(stats.timingLastIssueTick), 3); + EXPECT_EQ(value(stats.timingLastCompletionTick), 8); + EXPECT_EQ(stats.taskEvents[static_cast( + MatrixController::TaskEvent::ComputeIssue)], + 1); +} + +TEST(MatrixControllerTest, MmaTimingDoesNotScaleWithShape) +{ + tickHandler.setCurTick(0); + statistics::Group small_root(nullptr); + MatrixController small(&small_root); + small.setTimingConfig(testTimingConfig()); + small.setTileM(2); + small.setTileN(2); + small.setTileK(3); + small.mmaccWB(nullptr); + + tickHandler.setCurTick(0); + statistics::Group 
large_root(nullptr); + MatrixController large(&large_root); + large.setTimingConfig(testTimingConfig()); + large.setTileM(8); + large.setTileN(8); + large.setTileK(8); + large.mmaccWB(nullptr); + + EXPECT_EQ(small.controlSnapshot().timingLastIssueTick, 0); + EXPECT_EQ(large.controlSnapshot().timingLastIssueTick, 0); + EXPECT_EQ(small.controlSnapshot().timingLastCompletionTick, 5); + EXPECT_EQ(large.controlSnapshot().timingLastCompletionTick, 5); +} + +TEST(MatrixControllerTest, ReleaseWaitsForScheduledMatrixWork) +{ + tickHandler.setCurTick(0); + statistics::Group root(nullptr); + MatrixController controller(&root); + controller.setTimingConfig(testTimingConfig()); + controller.setTileM(2); + controller.setTileN(2); + controller.setTileK(3); + + controller.zeroAcc(nullptr); + controller.mmaccWB(nullptr); + controller.release(nullptr, 0); + + const auto snapshot = controller.controlSnapshot(); + EXPECT_EQ(snapshot.timingLastIssueTick, 8); + EXPECT_EQ(snapshot.timingLastCompletionTick, 15); + EXPECT_EQ(snapshot.timingNextIssueTick, 10); + EXPECT_EQ(snapshot.pendingTokenEvents, 1); + EXPECT_EQ(controller.readTokenForTest(0), 0); + + const auto &stats = controller.getStats(); + EXPECT_EQ(value(stats.releaseTasks), 1); + EXPECT_EQ(value(stats.tokenReleaseEvents), 1); + EXPECT_EQ(value(stats.timingTasks), 3); + EXPECT_EQ(value(stats.timingQueueTicks), 11); + EXPECT_EQ(value(stats.timingMaxQueueTicks), 8); + EXPECT_EQ(value(stats.timingBusyTicks), 15); + EXPECT_EQ(stats.taskEvents[static_cast( + MatrixController::TaskEvent::ReleaseIssue)], + 1); +} + +TEST(MatrixControllerTest, TimingConfigClampsZeroIssueInterval) +{ + tickHandler.setCurTick(0); + statistics::Group root(nullptr); + MatrixController controller(&root); + auto config = testTimingConfig(); + config.issueIntervalCycles = 0; + config.zeroCycles = 1; + controller.setTimingConfig(config); + + controller.zero(nullptr, 0, true); + controller.zero(nullptr, 1, false); + + const auto snapshot = 
controller.controlSnapshot(); + EXPECT_EQ(snapshot.timingLastIssueTick, 1); + EXPECT_EQ(snapshot.timingLastCompletionTick, 2); + EXPECT_EQ(snapshot.timingNextIssueTick, 2); + EXPECT_EQ(value(controller.getStats().timingQueueTicks), 1); +} + +TEST(MatrixControllerTest, MemoryPipelineLimitsOutstandingSources) +{ + tickHandler.setCurTick(0); + statistics::Group root(nullptr); + MatrixController controller(&root); + auto config = testTimingConfig(); + config.issueIntervalCycles = 1; + config.loadBaseCycles = 0; + config.localMmuArbCycles = 0; + config.l2RequestPipelineCycles = 0; + config.localMmuIssuePerCycle = MatrixController::LocalMmuSourceCount; + config.localMmuReadLatencyCycles = 10; + config.l2ResponsePipelineCycles = 1; + controller.setTimingConfig(config); + + controller.scheduleMemoryTimingForTest( + MatrixController::MemPort::A, false, + MatrixController::LocalMmuSourceCount + 1); + + const auto snapshot = controller.controlSnapshot(); + EXPECT_EQ(snapshot.timingLastIssueTick, 0); + EXPECT_EQ(snapshot.timingLocalMmuLastRequestTick, 11); + EXPECT_EQ(snapshot.timingLocalMmuLastResponseTick, 74); + EXPECT_EQ(snapshot.timingLastCompletionTick, 74); + EXPECT_EQ(snapshot.timingLocalMmuOutstanding, + MatrixController::LocalMmuSourceCount); + + const auto &stats = controller.getStats(); + EXPECT_EQ(value(stats.memoryPipelineRequests), 65); + EXPECT_EQ(value(stats.memoryPipelineReadResponses), 65); + EXPECT_EQ(value(stats.memoryPipelineWriteAcks), 0); + EXPECT_EQ(value(stats.memoryPipelineSourceStallTicks), 10); + EXPECT_EQ(value(stats.memoryPipelineRequestQueueTicks), 1); + EXPECT_EQ(value(stats.memoryPipelineResponseQueueTicks), 2069); + EXPECT_EQ(value(stats.memoryPipelineLastRequestTick), 11); + EXPECT_EQ(value(stats.memoryPipelineLastResponseTick), 74); + EXPECT_EQ(value(stats.memoryPipelineMaxOutstanding), + MatrixController::LocalMmuSourceCount); +} + +TEST(MatrixControllerTest, MemoryPipelineChoosesSourceAtIssueTick) +{ + tickHandler.setCurTick(0); + 
statistics::Group root(nullptr); + MatrixController controller(&root); + auto config = testTimingConfig(); + config.issueIntervalCycles = 1; + config.loadBaseCycles = 0; + config.localMmuArbCycles = 0; + config.l2RequestPipelineCycles = 0; + config.localMmuIssuePerCycle = 1; + config.localMmuReadLatencyCycles = 0; + config.l2ResponsePipelineCycles = 1; + controller.setTimingConfig(config); + + controller.scheduleMemoryTimingForTest( + MatrixController::MemPort::A, false, 2); + + const auto snapshot = controller.controlSnapshot(); + EXPECT_EQ(snapshot.timingLocalMmuLastRequestTick, 1); + EXPECT_EQ(snapshot.timingLocalMmuLastResponseTick, 1); + const auto &stats = controller.getStats(); + EXPECT_EQ(value(stats.memoryPipelineSourceStallTicks), 0); + EXPECT_EQ(value(stats.memoryPipelineRequestQueueTicks), 1); +} + +TEST(MatrixControllerTest, MemoryPipelineSeparatesWriteAckResponses) +{ + tickHandler.setCurTick(0); + statistics::Group root(nullptr); + MatrixController controller(&root); + auto config = testTimingConfig(); + config.issueIntervalCycles = 1; + config.storeBaseCycles = 0; + config.localMmuArbCycles = 0; + config.l2RequestPipelineCycles = 0; + config.localMmuIssuePerCycle = 2; + config.localMmuWriteAckLatencyCycles = 5; + config.l2ResponsePipelineCycles = 1; + controller.setTimingConfig(config); + + controller.scheduleMemoryTimingForTest( + MatrixController::MemPort::CStore, true, 2); + + const auto snapshot = controller.controlSnapshot(); + EXPECT_EQ(snapshot.timingLocalMmuLastRequestTick, 0); + EXPECT_EQ(snapshot.timingLocalMmuLastResponseTick, 6); + EXPECT_EQ(snapshot.timingLastCompletionTick, 6); + + const auto &stats = controller.getStats(); + EXPECT_EQ(value(stats.memoryPipelineRequests), 2); + EXPECT_EQ(value(stats.memoryPipelineReadResponses), 0); + EXPECT_EQ(value(stats.memoryPipelineWriteAcks), 2); + EXPECT_EQ(value(stats.memoryPipelineResponseQueueTicks), 1); + EXPECT_EQ(stats.memPortRequests[static_cast( + MatrixController::MemPort::CStore)], + 
2);
+}
+
+// Schedule a strided CStore described by base/stride/shape instead of an
+// explicit request count and check how many requests the pipeline records.
+TEST(MatrixControllerTest, CStoreTimingUsesRoundedStridedRequestCount)
+{
+    tickHandler.setCurTick(0);
+    statistics::Group root(nullptr);
+    MatrixController controller(&root);
+    auto config = testTimingConfig();
+    config.issueIntervalCycles = 1;
+    config.storeBaseCycles = 0;
+    config.localMmuArbCycles = 0;
+    config.l2RequestPipelineCycles = 0;
+    config.localMmuIssuePerCycle = MatrixController::LocalMmuSourceCount;
+    config.localMmuWriteAckLatencyCycles = 5;
+    config.l2ResponsePipelineCycles = 1;
+    controller.setTimingConfig(config);
+
+    // NOTE(review): arguments after the store flag are presumably
+    // (base, stride, rows, cols, width, transpose) -- confirm against the
+    // scheduleMemoryTimingForTest overload.
+    controller.scheduleMemoryTimingForTest(
+        MatrixController::MemPort::CStore, true, 0x103c, 0x80, 3, 2,
+        MatrixController::ElemWidth::E32, false);
+
+    const auto &stats = controller.getStats();
+    EXPECT_EQ(value(stats.memoryPipelineRequests), 8);
+    EXPECT_EQ(value(stats.memoryPipelineWriteAcks), 8);
+    // The enum value is converted to an index to select the vector entry.
+    EXPECT_EQ(stats.memPortRequests[static_cast<size_t>(
+                  MatrixController::MemPort::CStore)],
+              8);
+}
+
+// Allocate, release and re-allocate LocalMMU sources and check that the
+// highest-numbered free source is always the one handed out.
+TEST(MatrixControllerTest, ControlLocalMmuUsesHighestFreeSource)
+{
+    tickHandler.setCurTick(0);
+    statistics::Group root(nullptr);
+    MatrixController controller(&root);
+
+    auto snapshot = controller.controlSnapshot();
+    EXPECT_EQ(snapshot.nextLocalMmuSource,
+              MatrixController::LocalMmuSourceCount - 1);
+
+    const uint8_t first = controller.allocateLocalMmuSourceForTest(
+        MatrixController::MemPort::A);
+    const uint8_t second = controller.allocateLocalMmuSourceForTest(
+        MatrixController::MemPort::B);
+    EXPECT_EQ(first, MatrixController::LocalMmuSourceCount - 1);
+    EXPECT_EQ(second, MatrixController::LocalMmuSourceCount - 2);
+
+    // Both allocated sources must be marked busy in the source bitmask.
+    snapshot = controller.controlSnapshot();
+    EXPECT_EQ(snapshot.nextLocalMmuSource,
+              MatrixController::LocalMmuSourceCount - 3);
+    EXPECT_EQ(snapshot.localMmuOutstanding, 2);
+    EXPECT_TRUE(snapshot.localMmuBusySourceMask & (1ULL << first));
+    EXPECT_TRUE(snapshot.localMmuBusySourceMask & (1ULL << second));
+
+    controller.releaseLocalMmuSourceForTest(first);
+    snapshot = controller.controlSnapshot();
+
EXPECT_EQ(snapshot.nextLocalMmuSource, first); + EXPECT_EQ(snapshot.localMmuOutstanding, 1); + + const uint8_t reused = controller.allocateLocalMmuSourceForTest( + MatrixController::MemPort::CLoad); + EXPECT_EQ(reused, first); + EXPECT_EQ(value(controller.getStats().localMmuSourceAllocations), 3); + EXPECT_EQ(value(controller.getStats().localMmuSourceReleases), 1); +} + +TEST(MatrixControllerTest, SyncResetClearsPendingTokenEvents) +{ + tickHandler.setCurTick(0); + statistics::Group root(nullptr); + MatrixController controller(&root); + auto config = testTimingConfig(); + config.releaseCycles = 5; + controller.setTimingConfig(config); + + controller.release(nullptr, 3); + ASSERT_EQ(controller.controlSnapshot().pendingTokenEvents, 1); + + controller.syncReset(3); + + EXPECT_EQ(controller.controlSnapshot().pendingTokenEvents, 0); + EXPECT_EQ(controller.readTokenForTest(3), 0); + EXPECT_EQ(value(controller.getStats().tokenReleaseEvents), 1); +} + +TEST(MatrixControllerTest, LaterTaskDoesNotRetirePendingTokenRelease) +{ + tickHandler.setCurTick(0); + statistics::Group root(nullptr); + MatrixController controller(&root); + auto config = testTimingConfig(); + config.releaseCycles = 1; + controller.setTimingConfig(config); + + controller.release(nullptr, 0); + ASSERT_EQ(controller.tokenReadyTickForTest(0, 1), 1); + ASSERT_EQ(controller.readTokenForTest(0), 0); + + tickHandler.setCurTick(10); + controller.zeroAcc(nullptr); + + EXPECT_EQ(controller.readTokenForTest(0), 0); + EXPECT_EQ(controller.tokenReadyTickForTest(0, 1), 1); + + controller.retireReadyTokensUpTo(10); + EXPECT_EQ(controller.readTokenForTest(0), 1); +} + +} // namespace matrix +} // namespace gem5 diff --git a/src/matrix/matrix_stats_test_stub.cc b/src/matrix/matrix_stats_test_stub.cc new file mode 100644 index 0000000000..9bfedd90d8 --- /dev/null +++ b/src/matrix/matrix_stats_test_stub.cc @@ -0,0 +1,212 @@ +#include "base/statistics.hh" + +#include +#include +#include +#include + +#include 
"base/logging.hh" +#include "base/stats/info.hh" + +namespace gem5 +{ +namespace statistics +{ + +std::string Info::separatorString = "::"; +int Info::id_count = 0; + +Info::Info() + : flags(none), precision(-1), prereq(0), storageParams() +{ + id = id_count++; +} + +Info::~Info() +{ +} + +const StorageParams * +Info::getStorageParams() const +{ + return storageParams.get(); +} + +void +Info::setStorageParams(const StorageParams *const params) +{ + storageParams.reset(params); +} + +void +Info::setName(const std::string &stat_name, bool old_style) +{ + name = stat_name; +} + +bool +Info::less(Info *stat1, Info *stat2) +{ + return stat1->name < stat2->name; +} + +bool +Info::baseCheck() const +{ + return true; +} + +void +Info::enable() +{ +} + +void +VectorInfo::enable() +{ +} + +void +VectorDistInfo::enable() +{ +} + +void +Vector2dInfo::enable() +{ +} + +void +InfoAccess::setInfo(Group *parent, Info *info) +{ + _info = info; +} + +void +InfoAccess::setParams(const StorageParams *params) +{ + info()->setStorageParams(params); +} + +void +InfoAccess::setInit() +{ + info()->flags.set(init); +} + +Info * +InfoAccess::info() +{ + panic_if(_info == nullptr, "unit-test stat info is not initialized"); + return _info; +} + +const Info * +InfoAccess::info() const +{ + panic_if(_info == nullptr, "unit-test stat info is not initialized"); + return _info; +} + +bool +InfoAccess::newStyleStats() const +{ + return _info != nullptr; +} + +Group::Group(Group *parent, const char *name) + : mergedParent(nullptr) +{ + if (parent && name) { + parent->addStatGroup(name, this); + } else if (parent && !name) { + parent->mergeStatGroup(this); + } +} + +Group::~Group() +{ +} + +void +Group::regStats() +{ +} + +void +Group::resetStats() +{ + for (auto &stat : stats) { + stat->reset(); + } + for (auto &group : mergedStatGroups) { + group->resetStats(); + } + for (auto &group : statGroups) { + group.second->resetStats(); + } +} + +void +Group::preDumpStats() +{ +} + +void 
+Group::addStat(Info *info) +{ + stats.push_back(info); + if (mergedParent) { + mergedParent->addStat(info); + } +} + +void +Group::addStatGroup(const char *name, Group *block) +{ + panic_if(block == nullptr, "Can't add null stat group %s", name); + panic_if(block == this, "Stat group can't be added to itself"); + panic_if(statGroups.find(name) != statGroups.end(), + "Stats of the same group share the same name `%s`.\n", name); + statGroups[name] = block; +} + +const Info * +Group::resolveStat(std::string name) const +{ + for (auto &info : stats) { + if (info->name == name) { + return info; + } + } + return nullptr; +} + +void +Group::mergeStatGroup(Group *block) +{ + panic_if(block == nullptr, "No stat block provided"); + panic_if(block->mergedParent, "Stat group already merged"); + panic_if(block == this, "Stat group can't merge with itself"); + + mergedStatGroups.push_back(block); + for (auto &stat : block->stats) { + addStat(stat); + } + block->mergedParent = this; +} + +const std::map & +Group::getStatGroups() const +{ + return statGroups; +} + +const std::vector & +Group::getStats() const +{ + return stats; +} + +} // namespace statistics +} // namespace gem5