Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
132 changes: 132 additions & 0 deletions python/ChonAgent.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,132 @@
import random

import mjx
import sys
from datetime import datetime

from client.agent import CustomAgentBase

class ChonAgent(CustomAgentBase):
    """Rule-based mahjong agent built on top of ``CustomAgentBase``.

    ``custom_act`` applies the following priorities, first match wins:

    1. If only one action is legal, take it.
    2. Win (tsumo / ron) whenever possible.
    3. Declare riichi whenever possible.
    4. Call pon / open-kan, but only on the tile types listed in
       ``_CALL_TILE_TYPES``.
    5. Otherwise discard: lone honor tiles first, then tiles that
       ``effective_discard_types()`` reports and that have no neighbour in
       the hand, preferring non-red tiles and 1s / 9s.
    """

    # Tile-type codes for which a pon / open-kan is accepted.
    # NOTE(review): presumably the three dragon tiles in mjx's tile-type
    # numbering -- confirm against mjx.const.TileType.
    _CALL_TILE_TYPES = (31, 32, 33)

    def __init__(self):
        super().__init__()

    def custom_act(self, obs: mjx.Observation) -> mjx.Action:
        """Decide the action to take for the given observation.

        Args:
            obs (mjx.Observation): Board state together with the currently
                legal actions (``obs.legal_actions()``).

        Returns:
            mjx.Action: The action that will actually be played. It is always
            one of ``obs.legal_actions()``.
        """
        legal_actions = obs.legal_actions()
        if len(legal_actions) == 1:
            return legal_actions[0]

        # If it can win, just win.
        win_actions = [
            a
            for a in legal_actions
            if a.type() in (mjx.ActionType.TSUMO, mjx.ActionType.RON)
        ]
        if win_actions:
            assert len(win_actions) == 1
            return win_actions[0]

        # If it can declare riichi, just declare riichi.
        riichi_actions = [
            a for a in legal_actions if a.type() == mjx.ActionType.RIICHI
        ]
        if riichi_actions:
            assert len(riichi_actions) == 1
            return riichi_actions[0]

        # Accept pon / open-kan only for the wanted tile types; each action is
        # checked individually, then one of the accepted calls is picked at
        # random.
        call_actions = [
            a
            for a in legal_actions
            if a.type() in (mjx.ActionType.PON, mjx.ActionType.OPEN_KAN)
            and self._is_wanted_call(a)
        ]
        if call_actions:
            return random.choice(call_actions)

        legal_discards = [
            a
            for a in legal_actions
            if a.type() in (mjx.ActionType.DISCARD, mjx.ActionType.TSUMOGIRI)
        ]
        if not legal_discards:
            # Happens when a call was offered and declined above (or only a
            # chi was offered): there is nothing to discard, so pass.
            return self._decline(legal_actions)

        hand_tiles = obs.curr_hand().closed_tiles()

        # Lone honor tiles go first.
        honor_discard = self._pick_honor_discard(legal_discards, hand_tiles)
        if honor_discard is not None:
            return honor_discard

        effective_discard_types = obs.curr_hand().effective_discard_types()
        effective_discards = [
            a for a in legal_discards if a.tile().type() in effective_discard_types
        ]
        if not effective_discards:
            # Nothing is reported as an effective discard: discard the drawn
            # tile if that is possible, otherwise any legal discard.
            tsumogiri_actions = [
                a for a in legal_discards if a.type() == mjx.ActionType.TSUMOGIRI
            ]
            return tsumogiri_actions[0] if tsumogiri_actions else legal_discards[0]

        # Among the effective discards, prefer tiles with no pair / sequence
        # partner in the hand, and among those prefer non-red tiles.
        isolated_discards = [
            a
            for a in effective_discards
            if not self._has_neighbour(a.tile(), hand_tiles)
        ]
        non_red_discards = [a for a in isolated_discards if not a.tile().is_red()]
        candidates = non_red_discards or isolated_discards or effective_discards

        # 1s and 9s are given up before middle tiles.
        terminal_discards = [a for a in candidates if a.tile().num() in (1, 9)]
        if terminal_discards:
            return terminal_discards[0]
        return candidates[0]

    @classmethod
    def _is_wanted_call(cls, action):
        """Return True if ``action`` claims a tile listed in ``_CALL_TILE_TYPES``."""
        tile = action.tile()
        if tile is None:
            # NOTE(review): for call actions mjx appears to expose the claimed
            # tile through ``open().stolen_tile()`` rather than ``tile()`` --
            # confirm against the mjx Action / Open API.
            opened = action.open()
            tile = opened.stolen_tile() if opened is not None else None
        return tile is not None and tile.type() in cls._CALL_TILE_TYPES

    @staticmethod
    def _decline(legal_actions):
        """Return the PASS action, or the first legal action if there is none."""
        for action in legal_actions:
            if action.type() == mjx.ActionType.PASS:
                return action
        return legal_actions[0]

    @staticmethod
    def _pick_honor_discard(legal_discards, hand_tiles):
        """Return a discard of a number-less tile held at most twice, else None.

        Tiles whose ``num()`` is None are treated as honor tiles. The first
        such discard held at most once is preferred; failing that, the first
        one held at most twice.
        """
        honor_discards = [a for a in legal_discards if a.tile().num() is None]
        if not honor_discards:
            return None
        copies_held = [
            sum(1 for tile in hand_tiles if tile.type() == a.tile().type())
            for a in honor_discards
        ]
        for limit in (1, 2):
            for action, held in zip(honor_discards, copies_held):
                if held <= limit:
                    return action
        return None

    @staticmethod
    def _has_neighbour(target, hand_tiles):
        """Return True if the hand holds a pair or sequence partner of ``target``.

        A partner is another tile of the same type, or a numbered tile whose
        type code and number both differ from the target's by exactly one.
        Requiring both differences rules out the 9 -> 1 step between suits
        (assumes type codes are consecutive within a suit, as the original
        adjacency test did). Number-less tiles never form sequences.
        """
        target_id = target.id()
        target_type = target.type()
        target_num = target.num()
        for tile in hand_tiles:
            if tile.id() == target_id:
                continue  # the discard candidate itself
            if tile.type() == target_type:
                return True
            tile_num = tile.num()
            if target_num is None or tile_num is None:
                continue
            if abs(tile.type() - target_type) == 1 and abs(tile_num - target_num) == 1:
                return True
        return False



Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

from client.client import SocketIOClient
from client.agent import CustomAgentBase
from ChonAgent import ChonAgent

# CustomAgentBase を継承して,
# custom_act()を編集して麻雀AIを実装してください.
Expand All @@ -26,7 +27,7 @@ def custom_act(self, obs: mjx.Observation) -> mjx.Action:

if __name__ == "__main__":
# 4人で対局する場合は,4つのSocketIOClientで同一のサーバーに接続する.
my_agent = MyAgent() # 参加者が実装したプレイヤーをインスタンス化
my_agent = ChonAgent() # 参加者が実装したプレイヤーをインスタンス化

sio_client = SocketIOClient(
ip='localhost',
Expand Down
20 changes: 17 additions & 3 deletions python/sample_trial.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,11 @@

import mjx
import mjx.agents
from mjx.visualizer.visualizer import MahjongTable

from server import convert_log
from client.agent import CustomAgentBase
from ChonAgent import *


# CustomAgentBase を継承して,
Expand All @@ -38,9 +40,9 @@ def save_log(obs_dict, env, logs):
if not os.path.exists(logdir):
os.mkdir(logdir)

now = datetime.now().strftime('%Y%m%d%H%M%S%f')
now = datetime.now().strftime('%Y%m%d%H%M/%S%f')

os.mkdir(os.path.join(logdir, now))
os.makedirs(os.path.join(logdir, now))
for player_id, obs in obs_dict.items():
with open(os.path.join(logdir, now, f"{player_id}.json"), "w") as f:
json.dump(json.loads(obs.to_json()), f)
Expand Down Expand Up @@ -74,7 +76,7 @@ def save_log(obs_dict, env, logs):
}

agents = [
MyAgent(), # 自作Agent
ChonAgent(), # 自作Agent
mjx.agents.ShantenAgent(), # mjxに実装されているAgent
mjx.agents.ShantenAgent(), # mjxに実装されているAgent
mjx.agents.ShantenAgent(), # mjxに実装されているAgent
Expand All @@ -84,6 +86,8 @@ def save_log(obs_dict, env, logs):
env_ = mjx.MjxEnv()
obs_dict = env_.reset()

sums = {'player_3': 0, 'player_2': 0, 'player_1': 0, 'player_0': 0}

logs = convert_log.ConvertLog()
for _ in range(n_games):
while not env_.done():
Expand All @@ -93,8 +97,18 @@ def save_log(obs_dict, env, logs):
obs_dict = env_.step(actions)
if len(obs_dict.keys())==4:
logs.add_log(obs_dict)
if env_.done("round"):
print("===FINISH===")
returns = env_.rewards()
sums['player_0'] += returns['player_0']
sums['player_1'] += returns['player_1']
sums['player_2'] += returns['player_2']
sums['player_3'] += returns['player_3']
if logging:
save_log(obs_dict, env_, logs)

table = MahjongTable.from_proto(env_.state().to_proto())
print(sums)
print([a.score for a in table.players])
print("game has ended")