Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1 +1,5 @@
.DS_Store
lightning_logs/
json
/python/obs.npy
/python/actions.npy
17 changes: 17 additions & 0 deletions .vscode/settings.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
{
"cSpell.ignoreWords": [
"RIICHI",
"Shanten",
"TSUMO",
"argmax",
"dataloaders",
"inps",
"logdir",
"logit",
"optim",
"preds",
"proba",
"tenho",
"tgts"
]
}
26 changes: 26 additions & 0 deletions docker/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,29 @@ tqdm
python-socketio
eventlet
python-socketio[client]
aiohttp==3.8.4
aiosignal==1.3.1
async-timeout==4.0.2
attrs==22.2.0
cloudpickle==2.2.1
filelock==3.10.2
frozenlist==1.3.3
fsspec==2023.3.0
gym==0.26.2
gym-notices==0.0.8
Jinja2==3.1.2
lightning-utilities==0.8.0
MarkupSafe==2.1.2
mpmath==1.3.0
multidict==6.0.4
networkx==3.0
packaging==23.0
pytorch-lightning==2.0.0
PyYAML==6.0
torch
torchaudio
torchmetrics
torchvision
typing_extensions==4.5.0
sympy==1.11.1
yarl==1.8.2
27 changes: 27 additions & 0 deletions python/convert.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
from mjx import Observation, State, Action
import glob
import numpy as np
# Convert recorded game logs into supervised-learning arrays:
#   obs.npy     - stacked "mjx-small-v0" feature arrays, one per decision
#   actions.npy - the action index taken for each of those decisions

# Upper bound on how many log files are converted in one run.
MAX_FILES = 1000

# glob returns paths in arbitrary order; sort so every run picks the same files.
files = sorted(glob.glob("./json/json/*"))
obs_hist = []
action_hist = []
for count, file in enumerate(files[:MAX_FILES]):
    print(f"#{count}Loading file....")
    with open(file, encoding="utf-8") as f:
        # Each line is handed to State as one serialized state.
        for line in f:
            state = State(line)
            # Walk every (observation, action) pair recorded in this state.
            for cpp_obs, cpp_act in state._cpp_obj.past_decisions():
                obs = Observation._from_cpp_obj(cpp_obs)
                action = Action._from_cpp_obj(cpp_act)
                obs_hist.append(obs.to_features(feature_name="mjx-small-v0"))
                action_hist.append(action.to_idx())

if not obs_hist:
    # np.stack fails with an opaque message on an empty list; give context.
    raise ValueError("no decisions found under ./json/json/; nothing to save")

np.save("obs.npy", np.stack(obs_hist))
np.save("actions.npy", np.array(action_hist, dtype=np.int32))
77 changes: 77 additions & 0 deletions python/custom_client_riku0801.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
import torch
import mjx
from torch import optim, nn, utils, Tensor
import pytorch_lightning as pl
import mjx.agents
from client.agent import CustomAgentBase
from client.client import SocketIOClient

class MLP(pl.LightningModule):
    """Fully connected classifier with two hidden ReLU layers.

    Args:
        obs_size: Length of the flattened input feature vector.
        n_actions: Number of output logits.
        hidden_size: Width of both hidden layers.
    """

    def __init__(self, obs_size=544, n_actions=181, hidden_size=544):
        super().__init__()
        layers = [
            nn.Linear(obs_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, n_actions),
        ]
        # The attribute name and layer order determine the state_dict keys.
        self.net = nn.Sequential(*layers)
        self.loss_module = nn.CrossEntropyLoss()

    def forward(self, x):
        """Cast ``x`` to float and return the network's raw logits."""
        return self.net(x.float())

    def training_step(self, batch, batch_idx):
        """Compute the cross-entropy loss for one batch, log it, and return it."""
        features, targets = batch
        logits = self.forward(features)
        batch_loss = self.loss_module(logits, targets)
        self.log("train_loss", batch_loss)
        return batch_loss

    def configure_optimizers(self):
        """Return an Adam optimizer over all parameters with lr=1e-3."""
        return optim.Adam(self.parameters(), lr=1e-3)

# Build the network and restore the trained weights from disk.
model = MLP()
# map_location='cpu' lets the checkpoint load on hosts without CUDA even if it
# was saved from GPU tensors; the agent feeds the model CPU tensors.
model.load_state_dict(torch.load('./model_0.pth', map_location='cpu'))
# The model is only used for inference in this script.
model.eval()

class MyAgent(CustomAgentBase):
    """Agent that takes a win or riichi when offered, else follows the MLP."""

    def __init__(self) -> None:
        super().__init__()

    def custom_act(self, obs: mjx.Observation) -> mjx.Action:
        """Pick one of the legal actions for ``obs``.

        Order of preference:
          1. the only legal action, if there is exactly one;
          2. the first legal TSUMO, RON or RIICHI action;
          3. the legal action with the highest model logit.

        Args:
            obs: Current observation for this player.

        Returns:
            The chosen action, taken from ``obs.legal_actions()``.
        """
        legal_actions = obs.legal_actions()
        if len(legal_actions) == 1:
            return legal_actions[0]

        priority_types = (
            mjx.ActionType.TSUMO,
            mjx.ActionType.RON,
            mjx.ActionType.RIICHI,
        )
        for action in legal_actions:
            if action.type() in priority_types:
                return action

        feature = obs.to_features(feature_name="mjx-small-v0")
        with torch.no_grad():
            action_logit = model(Tensor(feature.ravel()))

        # Rank on raw logits with illegal entries forced to -inf. Multiplying
        # sigmoid outputs by the mask breaks when sigmoid saturates: large
        # logits all become exactly 1.0 (ties resolved by position, not score)
        # and very negative ones become 0.0, the same value as masked entries.
        # Assumes action_mask() is array-like with non-zero entries for legal
        # action indices - TODO confirm.
        legal = torch.as_tensor(obs.action_mask()).bool()
        masked_logit = action_logit.masked_fill(~legal, float("-inf"))
        action_idx = masked_logit.argmax().item()
        return mjx.Action.select_from(action_idx, legal_actions)

if __name__ == "__main__":
    # To play a four-player game, connect four SocketIOClient instances to the
    # same server.
    my_agent = MyAgent()  # instantiate the participant-implemented player

    client_options = dict(
        ip='localhost',
        port=5000,
        namespace='/test',
        query='secret',
        agent=my_agent,  # the player to use
        room_id=123,  # room ID; use the same ID for all four players of a game
    )
    sio_client = SocketIOClient(**client_options)

    # Run the SocketIO client instance, then enter the room.
    sio_client.run()
    sio_client.enter_room()
69 changes: 69 additions & 0 deletions python/learning.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
from torch import optim, nn, utils, Tensor
import pytorch_lightning as pl
import torch
import mjx
import mjx.agents
import numpy as np
from torch.utils.data import TensorDataset, DataLoader
from client.agent import CustomAgentBase


class MLP(pl.LightningModule):
    """Fully connected classifier with two hidden ReLU layers.

    Args:
        obs_size: Length of the flattened input feature vector.
        n_actions: Number of output logits.
        hidden_size: Width of both hidden layers.
    """

    def __init__(self, obs_size=544, n_actions=181, hidden_size=544):
        super().__init__()
        layers = [
            nn.Linear(obs_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, n_actions),
        ]
        # The attribute name and layer order determine the state_dict keys.
        self.net = nn.Sequential(*layers)
        self.loss_module = nn.CrossEntropyLoss()

    def forward(self, x):
        """Cast ``x`` to float and return the network's raw logits."""
        return self.net(x.float())

    def training_step(self, batch, batch_idx):
        """Compute the cross-entropy loss for one batch, log it, and return it."""
        features, targets = batch
        logits = self.forward(features)
        batch_loss = self.loss_module(logits, targets)
        self.log("train_loss", batch_loss)
        return batch_loss

    def configure_optimizers(self):
        """Return an Adam optimizer over all parameters with lr=1e-3."""
        return optim.Adam(self.parameters(), lr=1e-3)

# Load the observation features and action indices saved as .npy files.
inps = np.load("./obs.npy")
tgts = np.load("./actions.npy")
# Flatten each observation to a 16*34 vector. The sample count is taken from
# the array itself so the script works for any dataset size.
inps = inps.reshape(len(inps), 16 * 34)

dataset = TensorDataset(torch.Tensor(inps), torch.LongTensor(tgts))
loader = DataLoader(dataset, batch_size=2)

# Train for one epoch and save only the weights (state_dict).
model = MLP()
trainer = pl.Trainer(max_epochs=1)
trainer.fit(model=model, train_dataloaders=loader)
torch.save(model.state_dict(), './model_0.pth')

class MyAgent(CustomAgentBase):
    """Agent that plays the legal action the trained MLP scores highest."""

    def __init__(self) -> None:
        super().__init__()

    def act(self, obs: mjx.Observation) -> mjx.Action:
        """Pick one of the legal actions for ``obs``.

        Returns the only legal action when there is exactly one; otherwise the
        legal action with the highest model logit.

        NOTE(review): confirm that ``act`` is the hook CustomAgentBase expects
        subclasses to override.

        Args:
            obs: Current observation for this player.

        Returns:
            The chosen action, taken from ``obs.legal_actions()``.
        """
        legal_actions = obs.legal_actions()
        if len(legal_actions) == 1:
            return legal_actions[0]

        # Prediction
        feature = obs.to_features(feature_name="mjx-small-v0")
        with torch.no_grad():
            action_logit = model(Tensor(feature.ravel()))

        # Action selection: rank on raw logits with illegal entries forced to
        # -inf. Multiplying sigmoid outputs by the mask breaks when sigmoid
        # saturates: large logits all become exactly 1.0 (ties resolved by
        # position, not score) and very negative ones become 0.0, the same
        # value as masked entries.
        # Assumes action_mask() is array-like with non-zero entries for legal
        # action indices - TODO confirm.
        legal = torch.as_tensor(obs.action_mask()).bool()
        masked_logit = action_logit.masked_fill(~legal, float("-inf"))
        action_idx = masked_logit.argmax().item()
        return mjx.Action.select_from(action_idx, legal_actions)

Binary file added python/model_0.pth
Binary file not shown.
Loading