From c5d7502fa86d31da116039756881e39ff74210c6 Mon Sep 17 00:00:00 2001 From: sakata yuma Date: Fri, 24 Mar 2023 03:05:08 +0900 Subject: [PATCH] =?UTF-8?q?=E7=9C=A0=E3=81=84?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- python/ChonAgent.py | 132 ++++++++++++++++++ ..._client.py => custom_client_sakata1024.py} | 3 +- python/sample_trial.py | 20 ++- 3 files changed, 151 insertions(+), 4 deletions(-) create mode 100644 python/ChonAgent.py rename python/{sample_client.py => custom_client_sakata1024.py} (91%) diff --git a/python/ChonAgent.py b/python/ChonAgent.py new file mode 100644 index 0000000..4fc5abf --- /dev/null +++ b/python/ChonAgent.py @@ -0,0 +1,132 @@ +import random + +import mjx +import sys +from datetime import datetime + +from client.agent import CustomAgentBase + +class ChonAgent(CustomAgentBase): + def __init__(self): + super().__init__() + + def custom_act(self, obs: mjx.Observation) -> mjx.Action: + """盤面情報と取れる行動を受け取って,行動を決定して返す関数.参加者が各自で実装. + + Args: + obs (mjx.Observation): 盤面情報と取れる行動(obs.legal_actions()) + + Returns: + mjx.Action: 実際に取る行動 + """ + #print(obs.draws()[-1].type()) + legal_actions = obs.legal_actions() + if len(legal_actions) == 1: + #print("なんもしん") + return legal_actions[0] + + # if it can win, just win + win_actions = [a for a in legal_actions if a.type() in [mjx.ActionType.TSUMO, mjx.ActionType.RON]] + if len(win_actions) >= 1: + assert len(win_actions) == 1 + #print("かちました") + return win_actions[0] + + # if it can declare riichi, just declar riichi + riichi_actions = [a for a in legal_actions if a.type() == mjx.ActionType.RIICHI] + if len(riichi_actions) >= 1: + assert len(riichi_actions) == 1 + #print("立直") + return riichi_actions[0] + + # if it can apply chi/pon/open-kan, choose randomly + steal_actions = [ + a + for a in legal_actions + if a.type() in [mjx.ActionType.PON, mjx.ActionType.OPEN_KAN] + ] + if len(steal_actions) >= 1: + if a.tile().type() == 31 or a.tile().type() == 32 or a.tile().type() == 33: + #print("ぽんち") + return random.choice(steal_actions) + + + # discard an effective tile randomly + legal_discards = [ + a for a in legal_actions if a.type() in [mjx.ActionType.DISCARD, mjx.ActionType.TSUMOGIRI] + ] + + moji_discards = [a for a in legal_discards if a.tile().num() == None] + + if len(moji_discards) > 0: + for discard in moji_discards: + dis_tile = discard.tile() + count = 0 + for tile in obs.curr_hand().closed_tiles(): + if dis_tile.type() == tile.type(): + count = count + 1 + if count <= 1: + #print("1文字ツモ") + return discard + for discard in moji_discards: + dis_tile = discard.tile() + count = 0 + for tile in obs.curr_hand().closed_tiles(): + if dis_tile.type() == tile.type(): + count = count + 1 + if count <= 2: + #print("2文字ツモ") + return discard + + effective_discard_types = obs.curr_hand().effective_discard_types() + effective_discards = [ + a for a in legal_discards if a.tile().type() in effective_discard_types + ] + + if len(effective_discards) == 0: + tsumogiri_action = [a for a in legal_actions if a.type() == mjx.ActionType.TSUMOGIRI] + #print("ツモぎり") + return tsumogiri_action[0] + + more_effective_discards = [] + for discard in effective_discards: + find = False + for tile in obs.curr_hand().closed_tiles(): + if discard.tile().id() == tile.id(): + continue + if discard.tile().type()-1 <= tile.type() <= discard.tile().type()+1 and tile.num() + discard.tile().num() != 10: + find = True + break + if find == False: + more_effective_discards.append(discard) + + #print([a.tile().type() for a in effective_discards]) + #print([a.tile().type() for a in more_effective_discards]) + + use_discards = [a for a in more_effective_discards if not a.tile().is_red()] + if len(use_discards) == 0: + use_discards = more_effective_discards + + discard_1or9 = [a for a in use_discards if a.tile().num() == 1 or a.tile().num() == 9] + if len(discard_1or9) > 0: + #print("1or9ぎり") + return discard_1or9[0] + + if len(use_discards) > 0: + #print("ドラなし効果切り") + return use_discards[0] + + if len(more_effective_discards) > 0: + #print("ドラあり効果切り") + return more_effective_discards[0] + + discard_1or9 = [a for a in effective_discards if a.tile().num() == 1 or a.tile().num() == 9] + if len(discard_1or9) > 0: + #print("あきらめ1or9ぎり") + return discard_1or9[0] + + #print("あきらめ切り") + return effective_discards[0] + + + \ No newline at end of file diff --git a/python/sample_client.py b/python/custom_client_sakata1024.py similarity index 91% rename from python/sample_client.py rename to python/custom_client_sakata1024.py index 62dd582..ca17f39 100644 --- a/python/sample_client.py +++ b/python/custom_client_sakata1024.py @@ -4,6 +4,7 @@ from client.client import SocketIOClient from client.agent import CustomAgentBase +from ChonAgent import ChonAgent # CustomAgentBase を継承して, # custom_act()を編集して麻雀AIを実装してください. @@ -26,7 +27,7 @@ def custom_act(self, obs: mjx.Observation) -> mjx.Action: if __name__ == "__main__": # 4人で対局する場合は,4つのSocketIOClientで同一のサーバーに接続する. - my_agent = MyAgent() # 参加者が実装したプレイヤーをインスタンス化 + my_agent = ChonAgent() # 参加者が実装したプレイヤーをインスタンス化 sio_client = SocketIOClient( ip='localhost', diff --git a/python/sample_trial.py b/python/sample_trial.py index 5fa95d3..c55cacf 100644 --- a/python/sample_trial.py +++ b/python/sample_trial.py @@ -9,9 +9,11 @@ import mjx import mjx.agents +from mjx.visualizer.visualizer import MahjongTable from server import convert_log from client.agent import CustomAgentBase +from ChonAgent import * # CustomAgentBase を継承して, @@ -38,9 +40,9 @@ def save_log(obs_dict, env, logs): if not os.path.exists(logdir): os.mkdir(logdir) - now = datetime.now().strftime('%Y%m%d%H%M%S%f') + now = datetime.now().strftime('%Y%m%d%H%M/%S%f') - os.mkdir(os.path.join(logdir, now)) + os.makedirs(os.path.join(logdir, now)) for player_id, obs in obs_dict.items(): with open(os.path.join(logdir, now, f"{player_id}.json"), "w") as f: json.dump(json.loads(obs.to_json()), f) @@ -74,7 +76,7 @@ def save_log(obs_dict, env, logs): } agents = [ - MyAgent(), # 自作Agent + ChonAgent(), # 自作Agent mjx.agents.ShantenAgent(), # mjxに実装されているAgent mjx.agents.ShantenAgent(), # mjxに実装されているAgent mjx.agents.ShantenAgent(), # mjxに実装されているAgent @@ -84,6 +86,8 @@ def save_log(obs_dict, env, logs): env_ = mjx.MjxEnv() obs_dict = env_.reset() + sums = {'player_3': 0, 'player_2': 0, 'player_1': 0, 'player_0': 0} + logs = convert_log.ConvertLog() for _ in range(n_games): while not env_.done(): @@ -93,8 +97,18 @@ def save_log(obs_dict, env, logs): obs_dict = env_.step(actions) if len(obs_dict.keys())==4: logs.add_log(obs_dict) + if env_.done("round"): + print("===FINISH===") returns = env_.rewards() + sums['player_0'] += returns['player_0'] + sums['player_1'] += returns['player_1'] + sums['player_2'] += returns['player_2'] + sums['player_3'] += returns['player_3'] if logging: save_log(obs_dict, env_, logs) + + table = MahjongTable.from_proto(env_.state().to_proto()) + print(sums) + print([a.score for a in table.players]) print("game has ended")