Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
166 changes: 166 additions & 0 deletions fu/single/GepRTL.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,166 @@
"""
==========================================================================
GepRTL.py
==========================================================================
GetElementPtr (GEP) functional unit for CGRA tile.

Supports 1D and 2D address generation:
- OPT_GEP: result = base(in0) + index(in1)
- OPT_GEP_CONST: result = base(const) + index(in0)
- OPT_GEP_2D: result = base(in0) + index0(in1) * stride + index1(in2)
- OPT_GEP_2D_CONST: result = base(const) + index0(in0) * stride + index1(in1)

For 2D operations, the stride is pre-configured via CMD_CONFIG_GEP_STRIDE
through the recv_from_ctrl_mem interface before execution begins.

Author : Shangkun Li
Date : March 31, 2026
"""

from pymtl3 import *
from ..basic.Fu import Fu
from ...lib.opt_type import *
from ...lib.cmd_type import *

class GepRTL(Fu):

def construct(s, CtrlPktType, num_inports, num_outports, vector_factor_power = 0):

super(GepRTL, s).construct(CtrlPktType, num_inports, num_outports, 1, vector_factor_power)

num_entries = 2
FuInType = mk_bits(clog2(num_inports + 1))
CountType = mk_bits(clog2(num_entries + 1))

s.in0 = Wire(FuInType)
s.in1 = Wire(FuInType)
s.in2 = Wire(FuInType)

idx_nbits = clog2(num_inports)
s.in0_idx = Wire(idx_nbits)
s.in1_idx = Wire(idx_nbits)
s.in2_idx = Wire(idx_nbits)

s.in0_idx //= s.in0[0:idx_nbits]
s.in1_idx //= s.in1[0:idx_nbits]
s.in2_idx //= s.in2[0:idx_nbits]

s.recv_all_val = Wire(1)

# Stride register for 2D GEP, configured via CMD_CONFIG_GEP_STRIDE.
s.stride = Wire(s.DataType)

@update
def comb_logic():

s.recv_all_val @= 0
# For pick input register
s.in0 @= 0
s.in1 @= 0
s.in2 @= 0
for i in range(num_inports):
s.recv_in[i].rdy @= b1(0)
for i in range(num_outports):
s.send_out[i].val @= 0
s.send_out[i].msg @= s.DataType()

s.recv_const.rdy @= 0
s.recv_opt.rdy @= 0

s.send_to_ctrl_mem.val @= 0
s.send_to_ctrl_mem.msg @= s.CgraPayloadType(0, 0, 0, 0, 0)
s.recv_from_ctrl_mem.rdy @= 0

# Handle CMD configuration from ctrl_mem.
if s.recv_from_ctrl_mem.val:
s.recv_from_ctrl_mem.rdy @= b1(1)

if s.recv_opt.val:
if s.recv_opt.msg.fu_in[0] != 0:
s.in0 @= zext(s.recv_opt.msg.fu_in[0] - 1, FuInType)
if s.recv_opt.msg.fu_in[1] != 0:
s.in1 @= zext(s.recv_opt.msg.fu_in[1] - 1, FuInType)
if s.recv_opt.msg.fu_in[2] != 0:
s.in2 @= zext(s.recv_opt.msg.fu_in[2] - 1, FuInType)

if s.recv_opt.val:

# ===== OPT_GEP: 1D GEP with two input operands =====
# result = base(in0) + index(in1)
if s.recv_opt.msg.operation == OPT_GEP:
s.send_out[0].msg.payload @= s.recv_in[s.in0_idx].msg.payload + \
s.recv_in[s.in1_idx].msg.payload
s.send_out[0].msg.predicate @= s.recv_in[s.in0_idx].msg.predicate & \
s.recv_in[s.in1_idx].msg.predicate & \
s.reached_vector_factor
s.recv_all_val @= s.recv_in[s.in0_idx].val & s.recv_in[s.in1_idx].val
s.send_out[0].val @= s.recv_all_val
s.recv_in[s.in0_idx].rdy @= s.recv_all_val & s.send_out[0].rdy
s.recv_in[s.in1_idx].rdy @= s.recv_all_val & s.send_out[0].rdy
s.recv_opt.rdy @= s.recv_all_val & s.send_out[0].rdy

# ===== OPT_GEP_CONST: 1D GEP with const base =====
# result = base(const) + index(in0)
elif s.recv_opt.msg.operation == OPT_GEP_CONST:
s.send_out[0].msg.payload @= s.recv_const.msg.payload + \
s.recv_in[s.in0_idx].msg.payload
s.send_out[0].msg.predicate @= s.recv_in[s.in0_idx].msg.predicate & \
s.reached_vector_factor
s.recv_all_val @= s.recv_in[s.in0_idx].val & s.recv_const.val
s.send_out[0].val @= s.recv_all_val
s.recv_in[s.in0_idx].rdy @= s.recv_all_val & s.send_out[0].rdy
s.recv_const.rdy @= s.recv_all_val & s.send_out[0].rdy
s.recv_opt.rdy @= s.recv_all_val & s.send_out[0].rdy

# ===== OPT_GEP_2D: 2D GEP with three input operands =====
# result = base(in0) + index0(in1) * stride + index1(in2)
elif s.recv_opt.msg.operation == OPT_GEP_2D:
s.send_out[0].msg.payload @= s.recv_in[s.in0_idx].msg.payload + \
s.recv_in[s.in1_idx].msg.payload * s.stride.payload + \
s.recv_in[s.in2_idx].msg.payload
s.send_out[0].msg.predicate @= s.recv_in[s.in0_idx].msg.predicate & \
s.recv_in[s.in1_idx].msg.predicate & \
s.recv_in[s.in2_idx].msg.predicate & \
s.reached_vector_factor
s.recv_all_val @= s.recv_in[s.in0_idx].val & \
s.recv_in[s.in1_idx].val & \
s.recv_in[s.in2_idx].val
s.send_out[0].val @= s.recv_all_val
s.recv_in[s.in0_idx].rdy @= s.recv_all_val & s.send_out[0].rdy
s.recv_in[s.in1_idx].rdy @= s.recv_all_val & s.send_out[0].rdy
s.recv_in[s.in2_idx].rdy @= s.recv_all_val & s.send_out[0].rdy
s.recv_opt.rdy @= s.recv_all_val & s.send_out[0].rdy

# ===== OPT_GEP_2D_CONST: 2D GEP with const base =====
# result = base(const) + index0(in0) * stride + index1(in1)
elif s.recv_opt.msg.operation == OPT_GEP_2D_CONST:
s.send_out[0].msg.payload @= s.recv_const.msg.payload + \
s.recv_in[s.in0_idx].msg.payload * s.stride.payload + \
s.recv_in[s.in1_idx].msg.payload
s.send_out[0].msg.predicate @= s.recv_in[s.in0_idx].msg.predicate & \
s.recv_in[s.in1_idx].msg.predicate & \
s.reached_vector_factor
s.recv_all_val @= s.recv_in[s.in0_idx].val & \
s.recv_in[s.in1_idx].val & \
s.recv_const.val
s.send_out[0].val @= s.recv_all_val
s.recv_in[s.in0_idx].rdy @= s.recv_all_val & s.send_out[0].rdy
s.recv_in[s.in1_idx].rdy @= s.recv_all_val & s.send_out[0].rdy
s.recv_const.rdy @= s.recv_all_val & s.send_out[0].rdy
s.recv_opt.rdy @= s.recv_all_val & s.send_out[0].rdy

else:
for j in range(num_outports):
s.send_out[j].val @= b1(0)
s.recv_opt.rdy @= 0
s.recv_in[s.in0_idx].rdy @= 0
s.recv_in[s.in1_idx].rdy @= 0

@update_ff
def update_stride():
if s.reset:
s.stride <<= s.DataType(0, 0)
else:
if s.recv_from_ctrl_mem.val & \
(s.recv_from_ctrl_mem.msg.cmd == CMD_CONFIG_GEP_STRIDE):
s.stride <<= s.recv_from_ctrl_mem.msg.data
Loading
Loading