Skip to content

Commit 791d38f

Browse files
committed
Still playing with simple_arm64
1 parent 6ed0e1a commit 791d38f

File tree

2 files changed

+197
-44
lines changed

2 files changed

+197
-44
lines changed

tests/test_simple_arm64.py

Lines changed: 87 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -35,8 +35,22 @@ def __init__(self, instr_to_test, expected_result=None, immediat_accepted=None,
3535
self.expected_result = expected_result
3636
self.must_fail = must_fail
3737
self.debug = debug
38+
self.callargs = None
39+
3840

3941
def __call__(self, *args):
42+
assert args is not None
43+
self.callargs = args
44+
return self
45+
46+
def __repr__(self):
47+
if self.must_fail:
48+
return "MustFail:{0}{1}".format(self.instr_to_test.__name__, self.callargs)
49+
return "{0}{1}".format(self.instr_to_test.__name__, self.callargs)
50+
51+
def dotest(self):
52+
assert self.callargs is not None
53+
args = self.callargs
4054
try:
4155
if self.debug:
4256
import pdb;pdb.set_trace()
@@ -70,6 +84,7 @@ def __call__(self, *args):
7084
raise AssertionError("Not all bytes have been used by the disassembler")
7185
self.compare_mnemo(capres)
7286
self.compare_args(args, capres)
87+
return True
7388

7489
def compare_mnemo(self, capres):
7590
expected = self.instr_to_test.__name__.lower()
@@ -80,9 +95,15 @@ def compare_mnemo(self, capres):
8095

8196
def compare_args(self, args, capres):
8297
capres_op = list(capres.operands)
83-
if len(args) != len(capres_op):
84-
raise AssertionError("Expected {0} operands got {1}".format(len(args), len(capres_op)))
85-
for op_args, cap_op in zip(args, capres_op):
98+
# We may have != number of operand as shift are:
99+
# - arguments for simple_arm64
100+
# - attribute of the immediate for capstone
101+
if not len(capres_op) <= len(args):
102+
raise AssertionError("Expected at most {0} operands got {1}".format(len(args), len(capres_op)))
103+
104+
opargit = iter(args) # allow manually using next() to get next simple_arm64 arg for shift compare
105+
# capres_op must be first in zip (as it's smaller) or the last next(opargit) will be consumed by zip
106+
for cap_op, op_args in zip(capres_op, opargit):
86107
if isinstance(op_args, str): # Register
87108
if cap_op.type != capstone.arm64.ARM64_OP_REG:
88109
raise AssertionError("Expected args {0} operands got {1}".format(op_args, capres_op))
@@ -91,9 +112,39 @@ def compare_args(self, args, capres):
91112
elif isinstance(op_args, int_types):
92113
if (op_args != cap_op.imm) and not (self.immediat_accepted and self.immediat_accepted == cap_op.imm):
93114
raise AssertionError("Expected Immediat <{0}> got {1}".format(op_args, cap_op.imm))
115+
cap_shift = cap_op.shift
116+
if not (cap_shift.type == cap_shift.value == 0):
117+
self.compare_shift(next(opargit), cap_shift)
94118
else:
95119
raise ValueError("Unknow argument {0} of type {1}".format(op_args, type(op_args)))
96120

121+
# Check that no argument were unused in args
122+
# As args + shift should perfectly match the capres_op
123+
sentinel = object()
124+
nextarg = next(opargit, sentinel)
125+
if nextarg != sentinel:
126+
# Ignore a leading LSL #0 shift, as it should be authorized but not displayed by disassembler
127+
shift = Shift.parse(nextarg)
128+
if not (shift.type == "LSL" and shift.value == 0):
129+
raise ValueError("Non consomated argument: {0} (probable non-encoded shift)".format(nextarg))
130+
131+
SHIFT_TYPE_TO_CAPSTONE = {
132+
"LSL": capstone.arm64.ARM64_SFT_LSL,
133+
"LSR": capstone.arm64.ARM64_SFT_LSR,
134+
"ASR": capstone.arm64.ARM64_SFT_ASR,
135+
"ROR": capstone.arm64.ARM64_SFT_ROR,
136+
# "MSL": capstone.arm64.ARM64_SFT_MSL # Not yet used in PFW
137+
}
138+
139+
def compare_shift(self, shiftstr, cap_shift):
    """Check a simple_arm64 shift argument against capstone's decoded shift.

    shiftstr: shift argument as given to the assembler, e.g. "LSL #12".
    cap_shift: the `shift` attribute of a capstone operand (exposes `.type`
        and `.value`).

    Returns True when both the shift kind and the shift amount match.
    Raises ValueError when `shiftstr` is not a parseable shift, when its kind
    has no capstone equivalent in SHIFT_TYPE_TO_CAPSTONE, or on any mismatch.
    """
    shift = Shift.parse(shiftstr)
    if not shift:
        # Shift.parse() gives a falsy result for text it cannot parse; fail
        # with a clear message instead of an attribute error further down.
        raise ValueError("Expected a shift argument, got {0!r}".format(shiftstr))

    expected_type = self.SHIFT_TYPE_TO_CAPSTONE.get(shift.type)
    if expected_type is None:
        raise ValueError("Unsupported shift type: {0}".format(shift.type))

    if expected_type != cap_shift.type:
        raise ValueError("Shift type mismatch: expected {0} got {1}".format(shift.type, cap_shift.type))
    if shift.value != cap_shift.value:
        raise ValueError("Shift value mismatch: expected {0} got {1}".format(shift.value, cap_shift.value))
    return True
146+
147+
97148
def test_shift_parsing():
98149
assert Shift.parse("LSL #0")
99150
assert Shift.parse("LSL #12")
@@ -110,32 +161,36 @@ def test_shift_parsing():
110161
assert not Shift.parse("LSX ##1")
111162
assert not Shift.parse("LSX #")
112163

113-
def test_assembler():
114-
CheckInstr(Add)('W0', 'W0', 0)
115-
CheckInstr(Add)('W1', 'W0', 0)
116-
CheckInstr(Add)('W30', 'W12', 0)
117-
CheckInstr(Add)('W0', 'W0', 1)
118-
119-
CheckInstr(Add)('X0', 'X0', 0)
120-
CheckInstr(Add)('X30', 'X12', 0)
121-
CheckInstr(Add)('X0', 'X0', 1)
122-
CheckInstr(Add)('X11', 'X12', 0x123)
123-
# CheckInstr(Add)('X11', 'X12', 0x123, "LSL #0")
124-
CheckInstr(Add)('X11', 'X12', 0x123, "LSL #12")
125-
126-
# Error test todo
127-
# CheckInstr(Add)('X11', 'W12', 0x123)
128-
with pytest.raises(ValueError):
129-
CheckInstr(Add)('BADREG', 'X12', 0)
130-
with pytest.raises(ValueError):
131-
CheckInstr(Add)('X11', 'X12', 0x123, "LSL #1234")
132-
133-
with pytest.raises(ValueError):
134-
# Immediat too big for encoding
135-
CheckInstr(Add)('X11', 'X12', 0x12345678)
136-
137-
CheckInstr(Ret)("X0")
138-
CheckInstr(Ret, expected_result="ret ")("X30")
139-
CheckInstr(Ret)()
140-
with pytest.raises(ValueError):
141-
CheckInstr(Ret)("W0")
164+
# Each entry is a CheckInstr already bound to its operands (calling a
# CheckInstr stores the arguments and returns the checker itself).
# Entries built with must_fail=True describe operand sets flagged as invalid.
INSTRUCTION_ASSEMBLING_CASES = [
    # ADD (immediate), 32-bit registers
    CheckInstr(Add)('W0', 'W0', 0),
    CheckInstr(Add)('W1', 'W0', 0),
    CheckInstr(Add)('W30', 'W12', 0),
    CheckInstr(Add)('W0', 'W0', 1),
    # ADD (immediate), 64-bit registers
    CheckInstr(Add)('X0', 'X0', 0),
    CheckInstr(Add)('X30', 'X12', 0),
    CheckInstr(Add)('X0', 'X0', 1),
    CheckInstr(Add)('X11', 'X12', 0x123),
    CheckInstr(Add)('X11', 'X12', 0x123, "LSL #0"),
    CheckInstr(Add)('X11', 'X12', 0x123, "LSL #12"),
    # ADD, invalid operand sets
    CheckInstr(Add, must_fail=True)('X11', 'W12', 0x123),  # Bitness mismatch
    CheckInstr(Add, must_fail=True)('BADREG', 'X12', 0),
    CheckInstr(Add, must_fail=True)('X11', 'X12', 0x123, "LSL #1234"),
    CheckInstr(Add, must_fail=True)('X11', 'X12', 0x12345678),

    # MOVZ
    CheckInstr(Movz)('X0', 0),
    CheckInstr(Movz)('X0', 0, "LSL #32"),
    CheckInstr(Movz)('X18', 0, "LSL #48"),
    CheckInstr(Movz)('W18', 0, "LSL #16"),
    CheckInstr(Movz, must_fail=True)('X0', 0, "LSL #12"),  # Invalid LSL for MovWideImmediat
    CheckInstr(Movz, must_fail=True)('W0', 0, "LSL #32"),
    CheckInstr(Movz, must_fail=True)('X0', 0, "ROR #32"),

    # MOVK
    CheckInstr(Movk)('X0', 0x1234, "LSL #32"),
    CheckInstr(Movk)('X18', 0x5678, "LSL #48"),

    # RET
    CheckInstr(Ret)("X0"),
    CheckInstr(Ret, expected_result="ret ")("X30"),
    CheckInstr(Ret)(),
]


@pytest.mark.parametrize("checkinstr", INSTRUCTION_ASSEMBLING_CASES, ids=repr)
def test_instruction_assembling(checkinstr):
    """Run one prepared CheckInstr case; its repr() is used as the pytest id."""
    assert checkinstr.dotest()

windows/native_exec/simple_arm64.py

Lines changed: 110 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -113,7 +113,7 @@ def is_register(self, arg, accept_sp):
113113
return (accept_sp and (arg in [SP, WSP])) or arg in ALL_REGISTER
114114

115115
@classmethod
116-
def is_imm12(self, arg):
116+
def is_imm(self, arg):
117117
try:
118118
value = int(arg)
119119
except (ValueError, TypeError):
@@ -127,9 +127,9 @@ def is_shift(self, arg):
127127

128128
@classmethod
129129
def gen(cls, **encoding_array):
130-
class GeneratedEncoding(cls):
130+
class GeneratedEncodingCls(cls):
131131
ENCODING_VALUES = encoding_array
132-
return GeneratedEncoding
132+
return GeneratedEncodingCls
133133

134134
# Instruction filling at instantiation
135135

@@ -214,18 +214,52 @@ def __init__(self, argsdict):
214214
if shift not in [("LSL", 0), ("LSL", 12)]:
215215
raise ValueError("Invalid shift for instruction: {0}".format(shift))
216216
if shift == ("LSL", 12):
217-
import pdb;pdb.set_trace()
218217
self.sh[:] = bytearray((1,))
219218

220219

221220
@classmethod
222221
def accept_arg(cls, argsdict):
223222
return (cls.is_register(argsdict[0], accept_sp=True) and
224223
cls.is_register(argsdict[1], accept_sp=True) and
225-
cls.is_imm12(argsdict[2]) and
224+
cls.is_imm(argsdict[2]) and
226225
cls.is_shift(argsdict.get(3)))
227226

228227

228+
# C4.1.93.6 Logical (immediate)
229+
# Wtf : https://kddnewton.com/2022/08/11/aarch64-bitmask-immediates.html
230+
231+
class DataProcessingLogicalImmediate(DataProcessingImmediate):
    """Encoding skeleton for the "Logical (immediate)" instruction class.

    Field layout set up on `self.bits`:
        sf [31], opc [29:31], fixed 100100 in [23:29], N [22],
        immr [16:22], imms [10:16], rn [5:10], rd [0:5].

    The bitmask-immediate handling is not implemented yet: both
    `is_bitmask_imm` and `setup_bitmask_imm` raise NotImplementedError, so
    `accept_arg` currently raises as soon as both registers are accepted.
    """

    def __init__(self, argsdict):
        """Build the encoding from `argsdict` (index -> operand: rd, rn, imm)."""
        super(DataProcessingLogicalImmediate, self).__init__()
        self.sf = self.bits[31:32]
        self.opc = self.bits[29:31]
        self.bits[23:29] = bytearray(reversed((1, 0, 0, 1, 0, 0)))
        self.N = self.bits[22:23]
        self.immr = self.bits[16:22]
        self.imms = self.bits[10:16]
        self.rn = self.bits[5:10]
        self.rd = self.bits[0:5]

        self.setup_fixed_values()
        # Change instruction based on parameters
        self.setup_register(self.rd, argsdict[0])
        self.setup_register(self.rn, argsdict[1])
        # This layout has no imm12 field: the immediate is spread over the
        # N / immr / imms fields, which the method can reach through self.
        self.setup_bitmask_imm(argsdict[2])

    @classmethod
    def accept_arg(cls, argsdict):
        """Tell whether `argsdict` is (register, register, bitmask immediate)."""
        return (cls.is_register(argsdict[0], accept_sp=True) and
                    cls.is_register(argsdict[1], accept_sp=True) and
                    cls.is_bitmask_imm(argsdict[2]))

    @classmethod
    def is_bitmask_imm(cls, *args, **kwargs):
        """Placeholder: bitmask-immediate validation is not implemented."""
        raise NotImplementedError("is_bitmask_imm")

    def setup_bitmask_imm(self, *args, **kwargs):
        """Placeholder: bitmask-immediate encoding is not implemented."""
        raise NotImplementedError("setup_bitmask_imm")
261+
262+
229263
class MovWideImmediat(DataProcessingImmediate):
230264
def __init__(self, argsdict):
231265
super(MovWideImmediat, self).__init__()
@@ -242,13 +276,23 @@ def __init__(self, argsdict):
242276
self.setup_register(self.rd, argsdict[0])
243277
self.setup_immediat(self.imm16, argsdict[1])
244278

245-
assert argsdict.get(3) is None, "SHIFT NOT IMPLEMENTED YET"
279+
shift = Shift.parse(argsdict.get(2))
280+
if not shift:
281+
return
282+
if shift.type != "LSL":
283+
raise ValueError("Invalid shift type for {0} : {1}".format(type(self).__name__, shift.value))
284+
if shift.value not in (0, 16 ,32, 48):
285+
raise ValueError("Invalid shift value for {0} : {1}".format(type(self).__name__, shift.value))
286+
if self.bitness == 32 and shift.value > 16:
287+
raise ValueError("Invalid shift value for 32bits encoding of {0} : {1}".format(type(self).__name__, shift.value))
288+
289+
self.setup_immediat(self.hw, shift.value // 16)
246290

247291

248292
@classmethod
249293
def accept_arg(cls, argsdict):
250294
return (cls.is_register(argsdict[0], accept_sp=True) and
251-
cls.is_imm12(argsdict[1]) and
295+
cls.is_imm(argsdict[1]) and
252296
cls.is_shift(argsdict.get(2)))
253297

254298

@@ -293,24 +337,71 @@ def accept_arg(cls, argsdict):
293337
class DataProcessingRegister(InstructionEncoding):
294338
def __init__(self):
295339
super(DataProcessingRegister, self).__init__()
296-
self.bits[26:29] = bytearray((0,0,1))
297340
self.op0 = self.bits[30:31]
298341
self.op1 = self.bits[28:29]
299342
self.bits[25:28] = bytearray(reversed((1, 0, 1)))
300343
self.op2 = self.bits[21:25]
301344
self.op3 = self.bits[10:16]
302345

346+
class DataProcessingLogicalShiftedRegister(DataProcessingRegister):
    """Encoding for the "Logical (shifted register)" instruction class.

    Operands (by index in `argsdict`): rd, rn, rm and an optional shift
    string such as "LSL #3".

    Field layout set up on `self.bits`:
        sf [31], opc [29:31], fixed 01010 in [24:29], shift [22:24], N [21],
        rm [16:21], imm6 [10:16], rn [5:10], rd [0:5].
    """

    # Value of the 2-bit `shift` field for each shift kind.
    SHIFT_MAPPING = {"LSL": 0b00, "LSR": 0b01, "ASR": 0b10, "ROR": 0b11}

    def __init__(self, argsdict):
        """Build the encoding; raise ValueError on an invalid shift."""
        super(DataProcessingLogicalShiftedRegister, self).__init__()
        self.sf = self.bits[31:32]
        self.opc = self.bits[29:31]
        self.bits[24:29] = bytearray(reversed((0, 1, 0, 1, 0)))
        self.shift = self.bits[22:24]
        self.N = self.bits[21:22]
        self.rm = self.bits[16:21]
        self.imm6 = self.bits[10:16]
        self.rn = self.bits[5:10]
        self.rd = self.bits[0:5]

        self.setup_fixed_values()
        # Change instruction based on parameters
        self.setup_register(self.rd, argsdict[0])
        self.setup_register(self.rn, argsdict[1])
        self.setup_register(self.rm, argsdict[2])

        shift = Shift.parse(argsdict.get(3))
        if not shift:
            # No shift operand: leave the shift/imm6 fields untouched.
            return

        if shift.type not in self.SHIFT_MAPPING:
            raise ValueError("Invalid shift type for {0} : {1}".format(type(self).__name__, shift.type))
        if self.bitness == 32 and shift.value > 31:
            raise ValueError("Invalid shift value for 32bits encoding of {0} : {1}".format(type(self).__name__, shift.value))
        if shift.value > 63:
            # imm6 is a 6-bit field: larger amounts cannot be encoded.
            raise ValueError("Invalid shift value for {0} : {1}".format(type(self).__name__, shift.value))

        self.setup_immediat(self.shift, self.SHIFT_MAPPING[shift.type])
        self.setup_immediat(self.imm6, shift.value)

    @classmethod
    def accept_arg(cls, argsdict):
        """Tell whether `argsdict` is (register, register, register[, shift])."""
        # is_register() requires accept_sp explicitly; SP/WSP are not
        # accepted as operands of this encoding.
        return (cls.is_register(argsdict[0], accept_sp=False) and
                    cls.is_register(argsdict[1], accept_sp=False) and
                    cls.is_register(argsdict[2], accept_sp=False) and
                    cls.is_shift(argsdict.get(3)))
384+
303385
# An instruction is a name that can have multiple encodings
304386
# It's the class we instantiate to assemble instructions
305-
# Add X0, X0, IMM
306-
# Add X0, X0, X0
387+
# C6.2.270 ORR (immediate)
388+
# C6.2.271 ORR (shifted register)
389+
390+
# There also seem to be "alias instructions" like "mov"
391+
# that just map to other instructions when specific conditions are met on the params
392+
307393

308394
class Instruction(object):
309395
encoding = []
310396

311397
def __init__(self, *args):
312398
argsdict = dict(enumerate(args)) # Like a list but allow arg.get(4)
313-
for encodcls in self.encoding:
399+
for i, encodcls in enumerate(self.encoding):
400+
# Late rename of the GeneratedEncodingCls class for better error messages
401+
if encodcls.__name__ == "GeneratedEncodingCls":
402+
encodcls.__name__ = "{0}Encoding{1}".format(type(self).__name__, i)
403+
404+
314405
if encodcls.accept_arg(argsdict):
315406
self.encoded = encodcls(argsdict)
316407
return
@@ -346,10 +437,17 @@ class Ret(Instruction):
346437

347438
# C6.2.254
348439

349-
class MovZ(Instruction):
440+
class Movz(Instruction):
350441
encoding = [MovWideImmediat.gen(opc=0b10)]
351442

443+
class Movk(Instruction):
444+
encoding = [MovWideImmediat.gen(opc=0b11)]
352445

446+
# The encoding for "mov reg, reg" :D
447+
# C6.2.271
448+
# Todo: Instruction like "mov" that dispatch to other instruction encoding based on more precise condition on param ?
449+
class Orr(Instruction):
450+
encoding = [DataProcessingLogicalShiftedRegister.gen(opc=0b01)]
353451

354452
class MultipleInstr(object):
355453
INSTRUCTION_SIZE = 4

0 commit comments

Comments
 (0)