92 changes: 35 additions & 57 deletions examples/cpu/x86/matmul.py
@@ -24,9 +24,9 @@
 
 from lighthouse import dialects as lh_dialects
 from lighthouse.execution.runner import Runner
-from lighthouse.pipeline.driver import TransformDriver
+from lighthouse.pipeline.descriptor import Descriptor
+from lighthouse.pipeline.driver import PipelineDriver
 from lighthouse.utils.numpy import numpy_to_mlir_type
-from lighthouse.pipeline.helper import apply_registered_pass
 import lighthouse.utils as lh_utils
 from lighthouse import schedule as lh_schedule
 import lighthouse.schedule.x86 as lh_schedule_x86
@@ -65,6 +65,8 @@ def __init__(self, M: int, N: int, K: int, dtype=np.float32, tile_size: int = 32
         self.K = K
         self.dtype = dtype
         self.tile_size = tile_size
+        self.mod = ir.Module.create()
+        self.context = self.mod.context
 
     @cached_property
     def _input_arrays(self) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
@@ -88,9 +90,7 @@ def get_complexity(self) -> tuple[int, int, int]:
         return (flop_count, memory_reads, memory_writes)
 
     def payload_module(self) -> ir.Module:
-        mod = ir.Module.create()
-
-        with ir.InsertionPoint(mod.body):
+        with ir.InsertionPoint(self.mod.body):
             mlir_dtype = numpy_to_mlir_type(self.dtype)
 
             def tensor_t(shape, dtype=mlir_dtype):
@@ -118,47 +118,41 @@ def payload(A, B, C):
                 None, matmul, C, restrict=True, writable=True
             )
 
-        return mod
+        return self.mod
 
-    def schedule_modules(
+    def get_pipeline(
         self,
         stop_at_stage: Optional[str] = None,
         parameters: Optional[dict] = None,
-    ) -> list[ir.Module]:
-        scheds = []
+    ) -> PipelineDriver:
+        scheds = PipelineDriver(self.context)
 
         # Insert performance measurements.
-        scheds.append(Runner.get_bench_wrapper_schedule(self.payload_function_name))
+        scheds.add_transform(
+            Runner.get_bench_wrapper_schedule(self.payload_function_name)
+        )
 
         if stop_at_stage == "initial":
             return scheds
 
         # GEMM block packing.
         # Create cache-friendly access pattern across matmul tiles.
-        scheds.append(
+        scheds.add_transform(
             lh_schedule.block_pack_matmuls(
                 block_factors=[self.tile_size, self.tile_size, self.tile_size],
                 rhs_transpose_outer_block=True,
                 rhs_transpose_inner_block=False,
             )
         )
-        scheds.append(lh_schedule_x86.lower_packs_unpacks(self.tile_size))
+        scheds.add_transform(lh_schedule_x86.lower_packs_unpacks(self.tile_size))
 
         # Convert to category ops for easier op matching.
-        with lh_schedule.schedule_boilerplate() as (sched, named_seq):
-            ops = lh_transform.match_op(named_seq.bodyTarget, "func.func")
-            transform.apply_registered_pass(
-                transform.any_op_t(),
-                ops,
+        scheds.add_pass(
+            Descriptor(
                 "linalg-morph-ops",
-                options={
-                    "named-to-category": True,
-                    "generic-to-category": True,
-                },
+                opts={"named-to-category": True, "generic-to-category": True},
             )
-            lh_transform.cleanup(named_seq.bodyTarget)
-            transform.yield_()
-        scheds.append(sched)
+        )
 
         # GEMM cache tiling.
         # Create memory friendly access pattern.
@@ -171,11 +165,11 @@
                 )
                 transform.yield_()
             transform.yield_()
-        scheds.append(sched)
+        scheds.add_transform(sched)
 
         # Fold extra parallel outer unit dims before further tiling to help later
         # vectorization rewrites to recognize ops.
-        scheds.append(lh_schedule.linalg_contract_fold_unit_dims())
+        scheds.add_transform(lh_schedule.linalg_contract_fold_unit_dims())
 
         # GEMM register tiling.
         # Ensure that computation can fit into vector registers.
@@ -189,7 +183,7 @@
             reg_peel_loops.append(1)
         if self.tile_size % reg_tile_m != 0:
             reg_peel_loops.append(0)
-        scheds.append(
+        scheds.add_transform(
             lh_schedule.tile_ops(
                 gemm_op,
                 tile_sizes=[reg_tile_batch, reg_tile_m, reg_tile_n, reg_tile_k],
@@ -209,7 +203,7 @@
             reg_tile_n // reg_unroll_n,
             reg_tile_k // reg_unroll_k,
         ]
-        scheds.append(
+        scheds.add_transform(
             lh_schedule.tile_ops(
                 gemm_op,
                 tile_sizes=[0, reg_unroll_m, reg_unroll_n, reg_unroll_k],
@@ -218,15 +212,15 @@
         )
 
         # Further tiling into hardware-friendly sizes for vectorization.
-        scheds.append(lh_schedule.tile_ops("linalg.fill", tile_sizes=[1, 1, 1]))
-        scheds.append(lh_schedule.tile_ops("linalg.generic", tile_sizes=[1, 8]))
+        scheds.add_transform(lh_schedule.tile_ops("linalg.fill", tile_sizes=[1, 1, 1]))
+        scheds.add_transform(lh_schedule.tile_ops("linalg.generic", tile_sizes=[1, 8]))
 
         if stop_at_stage == "tiled":
             return scheds
 
         # Vectorization.
-        scheds.append(lh_schedule.vectorize_linalg())
-        scheds.append(lh_schedule.hoist_loops())
+        scheds.add_transform(lh_schedule.vectorize_linalg())
+        scheds.add_transform(lh_schedule.hoist_loops())
 
         with lh_schedule.schedule_boilerplate() as (sched, named_seq):
             with ir.InsertionPoint(
@@ -235,46 +229,32 @@
                 tensor.apply_patterns_tensor_fold_tensor_subset_ops_into_vector_transfers()
                 transform.apply_patterns_canonicalization()
             transform.yield_()
-        scheds.append(sched)
+        scheds.add_transform(sched)
 
         # Rewrite vector ops into x86-specific sequences.
-        scheds.append(lh_schedule.x86_vectorization())
+        scheds.add_transform(lh_schedule.x86_vectorization())
 
         # Lower to memrefs.
-        scheds.append(lh_schedule.bufferize(deallocation_pipeline=True))
+        scheds.add_descriptor(Descriptor("bufferization.yaml"))
+        scheds.add_descriptor(Descriptor("bufferization_cleanup.yaml"))
 
         # Apply x86 vectorization again as some patterns require memref abstraction.
-        scheds.append(lh_schedule.x86_vectorization())
+        scheds.add_transform(lh_schedule.x86_vectorization())
         # Vectorize any remaining ops.
-        scheds.append(lh_schedule.vectorize_all())
+        scheds.add_transform(lh_schedule.vectorize_all())
 
         # Cleanup vector ops.
         with lh_schedule.schedule_boilerplate() as (sched, named_seq):
             lh_transform.flatten_vector_ops(named_seq.bodyTarget)
             lh_transform.cleanup(named_seq.bodyTarget)
             transform.yield_()
-        scheds.append(sched)
+        scheds.add_transform(sched)
 
         if stop_at_stage == "vectorized":
             return scheds
 
         # Lower to LLVM.
-        with lh_schedule.schedule_boilerplate() as (sched, named_seq):
-            target = named_seq.bodyTarget
-            target = apply_registered_pass(target, "convert-linalg-to-loops")
-            target = apply_registered_pass(target, "fold-memref-alias-ops")
-            target = apply_registered_pass(target, "expand-strided-metadata")
-            target = apply_registered_pass(target, "canonicalize")
-            target = apply_registered_pass(target, "convert-vector-to-scf")
-            target = apply_registered_pass(target, "lower-affine")
-            target = apply_registered_pass(target, "convert-scf-to-cf")
-            target = apply_registered_pass(target, "convert-vector-to-llvm")
-            target = apply_registered_pass(target, "convert-to-llvm")
-            target = apply_registered_pass(target, "reconcile-unrealized-casts")
-            lh_transform.cleanup(target)
-
-            transform.yield_()
-        scheds.append(sched)
+        scheds.add_descriptor(Descriptor("llvm_lowering.yaml"))
 
         return scheds
 
@@ -363,9 +343,7 @@ def parse_cli():
         sys.exit(1)
 
     wload = Matmul(*args.sizes, dtype=in_dtype, tile_size=args.tile_size)
-    pipeline = TransformDriver(
-        wload.schedule_modules(stop_at_stage=args.dump_kernel)
-    )
+    pipeline = wload.get_pipeline(stop_at_stage=args.dump_kernel)
    payload = pipeline.apply(wload.payload_module())
 
     if args.dump_kernel or args.dump_schedule:
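Taken together, these changes replace the old list-of-schedules flow (schedule_modules wrapped in a TransformDriver) with a PipelineDriver that the workload builds itself. A minimal sketch of the new flow, using only the calls visible in this diff (the sizes, dtype, and tile size are illustrative):

    import numpy as np

    # The workload now owns the MLIR module and context (see __init__),
    # so the PipelineDriver returned by get_pipeline() is guaranteed to
    # share a context with the payload module it transforms.
    wload = Matmul(256, 256, 256, dtype=np.float32, tile_size=32)
    pipeline = wload.get_pipeline()  # a PipelineDriver, not a list of ir.Modules
    payload = pipeline.apply(wload.payload_module())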
4 changes: 4 additions & 0 deletions lighthouse/pipeline/driver.py
@@ -67,6 +67,10 @@ def add_stages(self, stages: list[Descriptor]) -> None:
         for s in stages:
             self.add_stage(s)
 
+    def add_descriptor(self, stage: Descriptor) -> None:
+        for s in PipelineDescriptor(stage).get_stages():
+            self.add_stage(s)
+
     def apply(self, module: ir.Module, print_after_all: bool = False) -> ir.Module:
         if module.context != self.context:
             raise ValueError("Module context does not match driver context.")
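The new add_descriptor helper expands a YAML-backed Descriptor into individual stages before queueing them. A rough sketch of the resulting equivalence, assuming PipelineDescriptor(...).get_stages() yields the same stage objects that add_stages accepts, as the loop above suggests:

    driver = PipelineDriver(context)

    # Under that assumption, these two calls queue the same stages:
    driver.add_descriptor(Descriptor("bufferization.yaml"))
    driver.add_stages(PipelineDescriptor(Descriptor("bufferization.yaml")).get_stages())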