Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Binary file added results/coding_ber_curve.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added results/equalization_eye_comparison.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added results/equalization_mse_curve.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
74 changes: 64 additions & 10 deletions src/part1_channel_coding.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,9 @@ def hamming74_encode(bits):
if not np.all((bits == 0) | (bits == 1)):
raise ValueError('bits 只能包含 0 或 1')

# TODO: 将 bits reshape 为 (-1, 4),再与 HAMMING_G 相乘并对 2 取模。
raise NotImplementedError('请实现 Hamming(7,4) 编码')
data_blocks = bits.reshape(-1, 4)
encoded_blocks = np.mod(data_blocks.dot(HAMMING_G), 2)
return encoded_blocks.reshape(-1)


def hamming74_syndrome(codewords):
Expand All @@ -70,8 +71,8 @@ def hamming74_syndrome(codewords):
if codewords.shape[1] != 7:
raise ValueError('每个 Hamming(7,4) 码字长度必须为 7')

# TODO: 计算 s = r H^T mod 2。
raise NotImplementedError('请实现伴随式计算')
syndromes = np.mod(codewords.dot(HAMMING_H.T), 2)
return syndromes


def hamming74_decode(received):
Expand All @@ -94,8 +95,20 @@ def hamming74_decode(received):
if received.ndim != 1 or len(received) % 7 != 0:
raise ValueError('received 必须是一维数组,长度为 7 的倍数')

# TODO: 使用 hamming74_syndrome 完成单比特纠错,并返回前 4 个信息位。
raise NotImplementedError('请实现 Hamming(7,4) 译码')
received = received.reshape(-1, 7)
syndromes = hamming74_syndrome(received)
corrected = received.copy()

# 每个非零伴随式都应该对应 H 的某一列,定位单比特错误。
for idx, syndrome in enumerate(syndromes):
if np.any(syndrome):
matches = np.all(HAMMING_H.T == syndrome, axis=1)
if np.any(matches):
error_position = np.argmax(matches)
corrected[idx, error_position] ^= 1

decoded_bits = corrected[:, :4].reshape(-1)
return decoded_bits


def convolutional_encode(bits):
Expand All @@ -108,8 +121,16 @@ def convolutional_encode(bits):
if not np.all((bits == 0) | (bits == 1)):
raise ValueError('bits 只能包含 0 或 1')

# TODO: 选做任务,可参考课件第6章卷积码部分。
raise NotImplementedError('选做:请实现卷积码编码')
extended = np.concatenate([bits, np.zeros(2, dtype=int)])
state = np.zeros(3, dtype=int)
encoded = []

for bit in extended:
state = np.array([bit, state[0], state[1]], dtype=int)
encoded.append(int(state[0] ^ state[1] ^ state[2]))
encoded.append(int(state[0] ^ state[2]))

return np.array(encoded, dtype=int)


def viterbi_decode_hard(received_bits):
Expand All @@ -120,8 +141,41 @@ def viterbi_decode_hard(received_bits):
if len(received_bits) % 2 != 0:
raise ValueError('卷积码接收序列长度必须是 2 的倍数')

# TODO: 选做任务,可使用汉明距离作为路径度量。
raise NotImplementedError('选做:请实现 Viterbi 硬判决译码')
received = received_bits.reshape(-1, 2)
num_steps = received.shape[0]
inf = 1e9

metrics = np.full((num_steps + 1, 4), inf)
prev_state = np.zeros((num_steps + 1, 4), dtype=int)
prev_input = np.zeros((num_steps + 1, 4), dtype=int)
metrics[0, 0] = 0

for i in range(num_steps):
for state in range(4):
metric = metrics[i, state]
if metric >= inf:
continue

m1 = (state >> 1) & 1
m2 = state & 1
for bit in (0, 1):
next_state = ((bit << 1) | m1) & 0b11
expected = np.array([bit ^ m1 ^ m2, bit ^ m2], dtype=int)
branch_metric = int(np.sum(received[i] != expected))
new_metric = metric + branch_metric
if new_metric < metrics[i + 1, next_state]:
metrics[i + 1, next_state] = new_metric
prev_state[i + 1, next_state] = state
prev_input[i + 1, next_state] = bit

final_state = 0 if metrics[num_steps, 0] < inf else int(np.argmin(metrics[num_steps]))
decoded = np.zeros(num_steps, dtype=int)
state = final_state
for i in range(num_steps, 0, -1):
decoded[i - 1] = prev_input[i, state]
state = prev_state[i, state]

return decoded[:-2] if num_steps >= 2 else decoded


def run_coding_demo():
Expand Down
44 changes: 38 additions & 6 deletions src/part2_equalization.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,24 @@ def estimate_zf_equalizer(channel, num_taps):
if num_taps < 1:
raise ValueError('num_taps 必须为正整数')

# TODO: 构造卷积矩阵并求解 ZF 均衡器抽头。
raise NotImplementedError('请实现 ZF 均衡器估计')
num_taps = int(num_taps)
num_channel = len(channel)
conv_len = num_channel + num_taps - 1

# 构造卷积矩阵 A,使得 A @ taps = channel * taps
A = np.zeros((conv_len, num_taps), dtype=float)
for i in range(conv_len):
for j in range(num_taps):
k = i - j
if 0 <= k < num_channel:
A[i, j] = channel[k]

d = np.zeros(conv_len, dtype=float)
center = conv_len // 2
d[center] = 1.0

taps, *_ = np.linalg.lstsq(A, d, rcond=None)
return taps


def apply_fir_filter(signal, taps):
Expand All @@ -58,8 +74,8 @@ def apply_fir_filter(signal, taps):
if signal.ndim != 1 or taps.ndim != 1:
raise ValueError('signal 和 taps 必须是一维数组')

# TODO: 使用 np.convolve,并截取与 signal 等长的输出。
raise NotImplementedError('请实现 FIR 滤波')
filtered = np.convolve(signal, taps, mode='full')[: len(signal)]
return filtered


def lms_equalizer(rx_train, tx_train, num_taps, step_size=0.01):
Expand Down Expand Up @@ -89,8 +105,24 @@ def lms_equalizer(rx_train, tx_train, num_taps, step_size=0.01):
if num_taps < 1:
raise ValueError('num_taps 必须为正整数')

# TODO: 实现 LMS 自适应均衡训练。
raise NotImplementedError('请实现 LMS 均衡器')
num_taps = int(num_taps)
taps = np.zeros(num_taps, dtype=float)
taps[num_taps // 2] = 1.0
errors = np.zeros(len(rx_train), dtype=float)

for n in range(len(rx_train)):
x = np.zeros(num_taps, dtype=float)
for k in range(num_taps):
idx = n - k
if idx >= 0:
x[k] = rx_train[idx]

y = np.dot(taps, x)
e = tx_train[n] - y
taps = taps + step_size * e * x
errors[n] = e

return taps, errors


def run_equalization_demo():
Expand Down
Loading