{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# 14wk-1: 강화학습 (2) – 4x4 Grid World (`AgentRandom`)\n",
"\n",
"최규빈 \n",
"2024-06-03\n",
"\n",
"
\n",
"\n",
"# 1. 강의영상\n",
"\n",
"\n",
"\n",
"# 2. Imports"
],
"id": "71778dd2-0d23-4415-b69c-b784d1855b8c"
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"#!pip install gymnasium\n",
"#---#\n",
"import gymnasium as gym\n",
"#---#\n",
"import numpy as np\n",
"import matplotlib.pyplot as plt\n",
"from matplotlib.animation import FuncAnimation\n",
"import IPython"
],
"id": "5b1cfd48-d338-4ddd-8b59-f8ab593ebda9"
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# 3. 4x4 Grid World\n",
"\n",
"`-` 문제설명: 4x4 그리드월드에서 상하좌우로 움직이는 에이전트가 목표점에\n",
"도달하도록 학습하는 방법\n",
"\n",
"`-` GridWorld에서 사용되는 주요변수\n",
"\n",
"1. **`State`**: 각 격자 셀이 하나의 상태이며, 에이전트는 이러한 상태 중\n",
" 하나에 있을 수 있음.\n",
"2. **`Action`**: 에이전트는 현재상태에서 다음상태로 이동하기 위해\n",
" 상,하,좌,우 중 하나의 행동을 취할 수 있음.\n",
"3. **`Reward`**: 에이전트가 현재상태에서 특정 action을 하면 얻어지는\n",
" 보상.\n",
"4. **`Terminated`**: 하나의 에피소드가 종료되었음을 나타내는 상태.\n",
"\n",
"# 4. 예비학습\n",
"\n",
"## A. `gym.spaces`\n",
"\n",
"`-` 예시1"
],
"id": "e3a4beed-ab07-464e-9a8d-584dee58e744"
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"action_space = gym.spaces.Discrete(4) \n",
"action_space "
],
"id": "86af6553-1d04-4f6a-b8ea-e91cb779e90e"
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"[action_space.sample() for _ in range(5)]"
],
"id": "365100ab-8550-4ceb-a79c-76b0537a34b3"
},
{
"cell_type": "code",
"execution_count": 22,
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"0 in action_space"
],
"id": "0a6049ab-0309-4f6b-a053-1aa56390c1c2"
},
{
"cell_type": "code",
"execution_count": 23,
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"4 in action_space"
],
"id": "cc93a879-f95e-4108-a349-d1664634205f"
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"`-` 예시2"
],
"id": "b30b1eab-54d8-437b-9c20-8bec6f1e218f"
},
{
"cell_type": "code",
"execution_count": 24,
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"state_space = gym.spaces.MultiDiscrete([4,4])\n",
"state_space"
],
"id": "19c4f90d-33ed-4a6f-bf36-070ac192d890"
},
{
"cell_type": "code",
"execution_count": 25,
"metadata": {
"scrolled": true,
"tags": []
},
"outputs": [],
"source": [
"[state_space.sample() for _ in range(5)]"
],
"id": "3f815cbc-036a-4497-965a-2727842f7ec5"
},
{
"cell_type": "code",
"execution_count": 26,
"metadata": {
"scrolled": true
},
"outputs": [],
"source": [
"np.array([0,1]) in state_space"
],
"id": "202a5afa"
},
{
"cell_type": "code",
"execution_count": 27,
"metadata": {
"scrolled": true,
"tags": []
},
"outputs": [],
"source": [
"np.array([3,3]) in state_space"
],
"id": "04da0368-8f85-43f2-adb1-eb5dd8465806"
},
{
"cell_type": "code",
"execution_count": 28,
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"np.array([3,4]) in state_space"
],
"id": "13f530d6-d1a7-40cb-89ec-b97c27b8339c"
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## B. 시각화"
],
"id": "42cd6c2e-9a53-46a1-85e6-923d69fd5b3a"
},
{
"cell_type": "code",
"execution_count": 29,
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"def show(states):\n",
" fig = plt.Figure()\n",
" ax = fig.subplots()\n",
" ax.matshow(np.zeros([4,4]), cmap='bwr',alpha=0.0)\n",
" sc = ax.scatter(0, 0, color='red', s=500) \n",
" ax.text(0, 0, 'start', ha='center', va='center')\n",
" ax.text(3, 3, 'end', ha='center', va='center')\n",
" # Adding grid lines to the plot\n",
" ax.set_xticks(np.arange(-.5, 4, 1), minor=True)\n",
" ax.set_yticks(np.arange(-.5, 4, 1), minor=True)\n",
" ax.grid(which='minor', color='black', linestyle='-', linewidth=2)\n",
" state_space = gym.spaces.MultiDiscrete([4,4])\n",
" def update(t):\n",
" if states[t] in state_space:\n",
" s1,s2 = states[t]\n",
" states[t] = [s2,s1]\n",
" sc.set_offsets(states[t])\n",
" else:\n",
" s1,s2 = states[t]\n",
" s1 = s1 + 0.5 if s1 < 0 else (s1 - 0.5 if s1 > 3 else s1)\n",
" s2 = s2 + 0.5 if s2 < 0 else (s2 - 0.5 if s2 > 3 else s2)\n",
" states[t] = [s2,s1] \n",
" sc.set_offsets(states[t])\n",
" ani = FuncAnimation(fig,update,frames=len(states))\n",
" display(IPython.display.HTML(ani.to_jshtml()))"
],
"id": "0b3ec243-70a1-46e7-bd95-2bb4a5338898"
},
{
"cell_type": "code",
"execution_count": 30,
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"show([[0,0],[1,0],[2,0],[3,0],[4,0]])"
],
"id": "5379f2af-62d5-43ff-b16f-4689e7a53fe5"
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# 5. Env 클래스 구현"
],
"id": "e6ad9a20-9eb1-44ec-abde-402ee97b5e64"
},
{
"cell_type": "code",
"execution_count": 31,
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"action_to_direction = {\n",
" 0 : np.array([1, 0]), # row+, down\n",
" 1 : np.array([0, 1]), # col+, right\n",
" 2 : np.array([-1 ,0]), # row-, up\n",
" 3 : np.array([0, -1]) # col-, left\n",
"}\n",
"action_to_direction2 = {0: 'down', 1: 'right', 2: 'up', 3: 'left'} # 당장쓰진 않지만 하는김에 "
],
"id": "23c93695"
},
{
"cell_type": "code",
"execution_count": 32,
"metadata": {},
"outputs": [],
"source": [
"action = action_space.sample()"
],
"id": "249cf618"
},
{
"cell_type": "code",
"execution_count": 33,
"metadata": {},
"outputs": [],
"source": [
"direction = action_to_direction[action]"
],
"id": "326085f3"
},
{
"cell_type": "code",
"execution_count": 34,
"metadata": {},
"outputs": [],
"source": [
"current_state = state_space.sample()\n",
"next_state = current_state + direction\n",
"current_state, direction, next_state"
],
"id": "d36b3a0b"
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"`-` Class 구현: 아래와 같은 느낌의 클래스를 구현해보자."
],
"id": "41a8f9bf-5db6-490e-88d5-3986c56bafa3"
},
{
"cell_type": "code",
"execution_count": 236,
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"class GridWorld:\n",
" def __init__(self):\n",
" self.state_space = gym.spaces.MultiDiscrete([4,4])\n",
" self.action_space = gym.spaces.Discrete(4) \n",
" self._action_to_direction = {\n",
" 0 : np.array([1, 0]), # row+, down\n",
" 1 : np.array([0, 1]), # col+, right\n",
" 2 : np.array([-1 ,0]), # row-, up\n",
" 3 : np.array([0, -1]) # col-, left\n",
" }\n",
" self.reset()\n",
" self.state = None \n",
" self.reward = None \n",
" self.termiated = None\n",
" def step(self,action):\n",
" direction = self._action_to_direction[action]\n",
" self.state = self.state + direction\n",
" if np.array_equal(self.state,np.array([3,3])): \n",
" self.reward = 100 \n",
" self.terminated = True\n",
" elif self.state not in self.state_space:\n",
" self.reward = -10\n",
" self.terminated = True\n",
" else:\n",
" self.reward = -1 \n",
" return self.state, self.reward, self.terminated\n",
" def reset(self):\n",
" self.state = np.array([0,0])\n",
" self.terminated = False \n",
" return self.state "
],
"id": "e89495dc"
},
{
"cell_type": "code",
"execution_count": 42,
"metadata": {},
"outputs": [],
"source": [
"env = GridWorld()\n",
"state = env.reset()\n",
"states = [] \n",
"states.append(state)\n",
"for t in range(50):\n",
" action = env.action_space.sample() \n",
" state,reward,terminated = env.step(action)\n",
" states.append(state)\n",
" if terminated: break "
],
"id": "a2358dfa"
},
{
"cell_type": "code",
"execution_count": 43,
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"show(states)"
],
"id": "886e5543-619a-4488-a39f-77bc8a2fa254"
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"- 처음에 바로 죽는 경우가 많아 몇번 시도하고 위의 애니메이션을 얻음\n",
"\n",
"# 6. `AgentRandom`\n",
"\n",
"## A. 에이전트 클래스 설계\n",
"\n",
"`-` 우리가 구현하고 싶은 기능\n",
"\n",
"- `.act()`: 액션을 결정 –\\> 여기서는 그냥 랜덤액션\n",
"- `.save_experience()`: 데이터를 저장 –\\> 여기에 일단 초점을 맞추자\n",
"- `.learn()`: 데이터로에서 학습 –\\> 패스"
],
"id": "4e5b3f92-2dd7-4dc2-81df-a8e94ae68ce7"
},
{
"cell_type": "code",
"execution_count": 143,
"metadata": {},
"outputs": [],
"source": [
"class AgentRandom: \n",
" def __init__(self,env):\n",
" #--# define spaces \n",
" self.action_space = env.action_space\n",
" self.state_space = env.state_space\n",
" #--# replay buffer \n",
" self.action = None \n",
" self.actions = [] \n",
" self.current_state = None \n",
" self.current_states = [] \n",
" self.reward = None \n",
" self.rewards = [] \n",
" self.next_state = None \n",
" self.next_states = [] \n",
" self.terminated = None \n",
" self.terminations = []\n",
" #--# other information\n",
" self.n_episodes = 0 \n",
" self.n_experiences = 0\n",
" self.score = 0 \n",
" self.playtimes = [] \n",
" self.scores = [] \n",
" def act(self):\n",
" self.action = self.action_space.sample()\n",
" def learn(self):\n",
" pass \n",
" def save_experience(self):\n",
" self.current_states.append(self.current_state) \n",
" self.actions.append(self.action)\n",
" self.rewards.append(self.reward) \n",
" self.next_states.append(self.next_state)\n",
" self.terminations.append(self.terminated)\n",
" #--#\n",
" self.n_experiences = self.n_experiences + 1 \n",
" self.score = self.score + self.reward"
],
"id": "b6a2a93a"
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## B. 환경과 상호작용"
],
"id": "3927789c-0b99-4969-a79e-20a9472d752d"
},
{
"cell_type": "code",
"execution_count": 140,
"metadata": {},
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"에피소드: 1 점수(에피소드): -12 게임시간(에피소드): 3 경험수: 3\n",
"에피소드: 2 점수(에피소드): -15 게임시간(에피소드): 6 경험수: 9\n",
"에피소드: 3 점수(에피소드): -11 게임시간(에피소드): 2 경험수: 11\n",
"에피소드: 4 점수(에피소드): -10 게임시간(에피소드): 1 경험수: 12\n",
"에피소드: 5 점수(에피소드): -10 게임시간(에피소드): 1 경험수: 13\n",
"에피소드: 6 점수(에피소드): -12 게임시간(에피소드): 3 경험수: 16\n",
"에피소드: 7 점수(에피소드): -11 게임시간(에피소드): 2 경험수: 18\n",
"에피소드: 8 점수(에피소드): -18 게임시간(에피소드): 9 경험수: 27\n",
"에피소드: 9 점수(에피소드): -10 게임시간(에피소드): 1 경험수: 28\n",
"에피소드: 10 점수(에피소드): 91 게임시간(에피소드): 10 경험수: 38\n",
"에피소드: 11 점수(에피소드): -10 게임시간(에피소드): 1 경험수: 39\n",
"에피소드: 12 점수(에피소드): -10 게임시간(에피소드): 1 경험수: 40\n",
"에피소드: 13 점수(에피소드): -11 게임시간(에피소드): 2 경험수: 42\n",
"에피소드: 14 점수(에피소드): -10 게임시간(에피소드): 1 경험수: 43\n",
"에피소드: 15 점수(에피소드): -10 게임시간(에피소드): 1 경험수: 44\n",
"에피소드: 16 점수(에피소드): -10 게임시간(에피소드): 1 경험수: 45\n",
"에피소드: 17 점수(에피소드): -10 게임시간(에피소드): 1 경험수: 46\n",
"에피소드: 18 점수(에피소드): -15 게임시간(에피소드): 6 경험수: 52\n",
"에피소드: 19 점수(에피소드): -11 게임시간(에피소드): 2 경험수: 54\n",
"에피소드: 20 점수(에피소드): -10 게임시간(에피소드): 1 경험수: 55\n",
"에피소드: 21 점수(에피소드): -10 게임시간(에피소드): 1 경험수: 56\n",
"에피소드: 22 점수(에피소드): -12 게임시간(에피소드): 3 경험수: 59\n",
"에피소드: 23 점수(에피소드): -11 게임시간(에피소드): 2 경험수: 61\n",
"에피소드: 24 점수(에피소드): -11 게임시간(에피소드): 2 경험수: 63\n",
"에피소드: 25 점수(에피소드): -13 게임시간(에피소드): 4 경험수: 67\n",
"에피소드: 26 점수(에피소드): -10 게임시간(에피소드): 1 경험수: 68\n",
"에피소드: 27 점수(에피소드): -10 게임시간(에피소드): 1 경험수: 69\n",
"에피소드: 28 점수(에피소드): -11 게임시간(에피소드): 2 경험수: 71\n",
"에피소드: 29 점수(에피소드): -10 게임시간(에피소드): 1 경험수: 72\n",
"에피소드: 30 점수(에피소드): -13 게임시간(에피소드): 4 경험수: 76\n",
"에피소드: 31 점수(에피소드): -10 게임시간(에피소드): 1 경험수: 77\n",
"에피소드: 32 점수(에피소드): -10 게임시간(에피소드): 1 경험수: 78\n",
"에피소드: 33 점수(에피소드): -18 게임시간(에피소드): 9 경험수: 87\n",
"에피소드: 34 점수(에피소드): -13 게임시간(에피소드): 4 경험수: 91\n",
"에피소드: 35 점수(에피소드): -18 게임시간(에피소드): 9 경험수: 100\n",
"에피소드: 36 점수(에피소드): -10 게임시간(에피소드): 1 경험수: 101\n",
"에피소드: 37 점수(에피소드): -15 게임시간(에피소드): 6 경험수: 107\n",
"에피소드: 38 점수(에피소드): -10 게임시간(에피소드): 1 경험수: 108\n",
"에피소드: 39 점수(에피소드): -13 게임시간(에피소드): 4 경험수: 112\n",
"에피소드: 40 점수(에피소드): -17 게임시간(에피소드): 8 경험수: 120\n",
"에피소드: 41 점수(에피소드): -10 게임시간(에피소드): 1 경험수: 121\n",
"에피소드: 42 점수(에피소드): -10 게임시간(에피소드): 1 경험수: 122\n",
"에피소드: 43 점수(에피소드): -10 게임시간(에피소드): 1 경험수: 123\n",
"에피소드: 44 점수(에피소드): -10 게임시간(에피소드): 1 경험수: 124\n",
"에피소드: 45 점수(에피소드): -10 게임시간(에피소드): 1 경험수: 125\n",
"에피소드: 46 점수(에피소드): -17 게임시간(에피소드): 8 경험수: 133\n",
"에피소드: 47 점수(에피소드): -10 게임시간(에피소드): 1 경험수: 134\n",
"에피소드: 48 점수(에피소드): -13 게임시간(에피소드): 4 경험수: 138\n",
"에피소드: 49 점수(에피소드): -12 게임시간(에피소드): 3 경험수: 141\n",
"에피소드: 50 점수(에피소드): -11 게임시간(에피소드): 2 경험수: 143"
]
}
],
"source": [
"env = GridWorld()\n",
"agent = AgentRandom(env)\n",
"#--#\n",
"for _ in range(50):\n",
" agent.current_state = env.reset()\n",
" agent.score = 0 \n",
" for t in range(100):\n",
" # step1: 행동\n",
" agent.act()\n",
" # step2: 보상\n",
" agent.next_state, agent.reward, agent.terminated = env.step(agent.action)\n",
" # step3: 저장 & 학습\n",
" agent.save_experience()\n",
" agent.learn()\n",
" # step4: \n",
" agent.current_state = agent.next_state\n",
" if agent.terminated: break\n",
" agent.scores.append(agent.score) \n",
" agent.playtimes.append(t+1)\n",
" agent.n_episodes = agent.n_episodes + 1 \n",
" #---#\n",
" print(\n",
" f\"에피소드: {agent.n_episodes} \\t\"\n",
" f\"점수(에피소드): {agent.scores[-1]} \\t\" \n",
" f\"게임시간(에피소드): {agent.playtimes[-1]}\\t\"\n",
" f\"경험수: {agent.n_experiences}\"\n",
" )"
],
"id": "50846003"
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## C. 상호작용결과 시각화"
],
"id": "50c4dc0f-df3c-4955-9b8f-897a67132b19"
},
{
"cell_type": "code",
"execution_count": 141,
"metadata": {},
"outputs": [],
"source": [
"[np.array([0,0])] + agent.next_states[28:38] # 에피소드10"
],
"id": "c5604215"
},
{
"cell_type": "code",
"execution_count": 142,
"metadata": {},
"outputs": [],
"source": [
"show([np.array([0,0])] + agent.next_states[28:38]) # 에피소드5"
],
"id": "9b695ca7"
}
],
"nbformat": 4,
"nbformat_minor": 5,
"metadata": {
"kernelspec": {
"name": "python3",
"display_name": "Python 3 (ipykernel)",
"language": "python"
},
"language_info": {
"name": "python",
"codemirror_mode": {
"name": "ipython",
"version": "3"
},
"file_extension": ".py",
"mimetype": "text/x-python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.8"
}
}
}