{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# 14wk-1: 강화학습 (2) – 4x4 Grid World (`AgentRandom`)\n", "\n", "최규빈 \n", "2024-06-03\n", "\n", "\n", "\n", "# 1. 강의영상\n", "\n", "\n", "\n", "# 2. Imports" ], "id": "71778dd2-0d23-4415-b69c-b784d1855b8c" }, { "cell_type": "code", "execution_count": 2, "metadata": { "tags": [] }, "outputs": [], "source": [ "#!pip install gymnasium\n", "#---#\n", "import gymnasium as gym\n", "#---#\n", "import numpy as np\n", "import matplotlib.pyplot as plt\n", "from matplotlib.animation import FuncAnimation\n", "import IPython" ], "id": "5b1cfd48-d338-4ddd-8b59-f8ab593ebda9" }, { "cell_type": "markdown", "metadata": {}, "source": [ "# 3. 4x4 Grid World\n", "\n", "`-` 문제설명: 4x4 그리드월드에서 상하좌우로 움직이는 에이전트가 목표점에\n", "도달하도록 학습하는 방법\n", "\n", "`-` GridWorld에서 사용되는 주요변수\n", "\n", "1. **`State`**: 각 격자 셀이 하나의 상태이며, 에이전트는 이러한 상태 중\n", " 하나에 있을 수 있음.\n", "2. **`Action`**: 에이전트는 현재상태에서 다음상태로 이동하기 위해\n", " 상,하,좌,우 중 하나의 행동을 취할 수 있음.\n", "3. **`Reward`**: 에이전트가 현재상태에서 특정 action을 하면 얻어지는\n", " 보상.\n", "4. **`Terminated`**: 하나의 에피소드가 종료되었음을 나타내는 상태.\n", "\n", "# 4. 예비학습\n", "\n", "## A. `gym.spaces`\n", "\n", "`-` 예시1" ], "id": "e3a4beed-ab07-464e-9a8d-584dee58e744" }, { "cell_type": "code", "execution_count": 3, "metadata": { "tags": [] }, "outputs": [], "source": [ "action_space = gym.spaces.Discrete(4) \n", "action_space " ], "id": "86af6553-1d04-4f6a-b8ea-e91cb779e90e" }, { "cell_type": "code", "execution_count": 4, "metadata": { "tags": [] }, "outputs": [], "source": [ "[action_space.sample() for _ in range(5)]" ], "id": "365100ab-8550-4ceb-a79c-76b0537a34b3" }, { "cell_type": "code", "execution_count": 22, "metadata": { "tags": [] }, "outputs": [], "source": [ "0 in action_space" ], "id": "0a6049ab-0309-4f6b-a053-1aa56390c1c2" }, { "cell_type": "code", "execution_count": 23, "metadata": { "tags": [] }, "outputs": [], "source": [ "4 in action_space" ], "id": "cc93a879-f95e-4108-a349-d1664634205f" }, { "cell_type": "markdown", "metadata": {}, "source": [ "`-` 예시2" ], "id": "b30b1eab-54d8-437b-9c20-8bec6f1e218f" }, { "cell_type": "code", "execution_count": 24, "metadata": { "tags": [] }, "outputs": [], "source": [ "state_space = gym.spaces.MultiDiscrete([4,4])\n", "state_space" ], "id": "19c4f90d-33ed-4a6f-bf36-070ac192d890" }, { "cell_type": "code", "execution_count": 25, "metadata": { "scrolled": true, "tags": [] }, "outputs": [], "source": [ "[state_space.sample() for _ in range(5)]" ], "id": "3f815cbc-036a-4497-965a-2727842f7ec5" }, { "cell_type": "code", "execution_count": 26, "metadata": { "scrolled": true }, "outputs": [], "source": [ "np.array([0,1]) in state_space" ], "id": "202a5afa" }, { "cell_type": "code", "execution_count": 27, "metadata": { "scrolled": true, "tags": [] }, "outputs": [], "source": [ "np.array([3,3]) in state_space" ], "id": "04da0368-8f85-43f2-adb1-eb5dd8465806" }, { "cell_type": "code", "execution_count": 28, "metadata": { "tags": [] }, "outputs": [], "source": [ "np.array([3,4]) in state_space" ], "id": "13f530d6-d1a7-40cb-89ec-b97c27b8339c" }, { "cell_type": "markdown", "metadata": {}, "source": [ "## B. 시각화" ], "id": "42cd6c2e-9a53-46a1-85e6-923d69fd5b3a" }, { "cell_type": "code", "execution_count": 29, "metadata": { "tags": [] }, "outputs": [], "source": [ "def show(states):\n", " fig = plt.Figure()\n", " ax = fig.subplots()\n", " ax.matshow(np.zeros([4,4]), cmap='bwr',alpha=0.0)\n", " sc = ax.scatter(0, 0, color='red', s=500) \n", " ax.text(0, 0, 'start', ha='center', va='center')\n", " ax.text(3, 3, 'end', ha='center', va='center')\n", " # Adding grid lines to the plot\n", " ax.set_xticks(np.arange(-.5, 4, 1), minor=True)\n", " ax.set_yticks(np.arange(-.5, 4, 1), minor=True)\n", " ax.grid(which='minor', color='black', linestyle='-', linewidth=2)\n", " state_space = gym.spaces.MultiDiscrete([4,4])\n", " def update(t):\n", " if states[t] in state_space:\n", " s1,s2 = states[t]\n", " states[t] = [s2,s1]\n", " sc.set_offsets(states[t])\n", " else:\n", " s1,s2 = states[t]\n", " s1 = s1 + 0.5 if s1 < 0 else (s1 - 0.5 if s1 > 3 else s1)\n", " s2 = s2 + 0.5 if s2 < 0 else (s2 - 0.5 if s2 > 3 else s2)\n", " states[t] = [s2,s1] \n", " sc.set_offsets(states[t])\n", " ani = FuncAnimation(fig,update,frames=len(states))\n", " display(IPython.display.HTML(ani.to_jshtml()))" ], "id": "0b3ec243-70a1-46e7-bd95-2bb4a5338898" }, { "cell_type": "code", "execution_count": 30, "metadata": { "tags": [] }, "outputs": [], "source": [ "show([[0,0],[1,0],[2,0],[3,0],[4,0]])" ], "id": "5379f2af-62d5-43ff-b16f-4689e7a53fe5" }, { "cell_type": "markdown", "metadata": {}, "source": [ "# 5. Env 클래스 구현" ], "id": "e6ad9a20-9eb1-44ec-abde-402ee97b5e64" }, { "cell_type": "code", "execution_count": 31, "metadata": { "tags": [] }, "outputs": [], "source": [ "action_to_direction = {\n", " 0 : np.array([1, 0]), # row+, down\n", " 1 : np.array([0, 1]), # col+, right\n", " 2 : np.array([-1 ,0]), # row-, up\n", " 3 : np.array([0, -1]) # col-, left\n", "}\n", "action_to_direction2 = {0: 'down', 1: 'right', 2: 'up', 3: 'left'} # 당장쓰진 않지만 하는김에 " ], "id": "23c93695" }, { "cell_type": "code", "execution_count": 32, "metadata": {}, "outputs": [], "source": [ "action = action_space.sample()" ], "id": "249cf618" }, { "cell_type": "code", "execution_count": 33, "metadata": {}, "outputs": [], "source": [ "direction = action_to_direction[action]" ], "id": "326085f3" }, { "cell_type": "code", "execution_count": 34, "metadata": {}, "outputs": [], "source": [ "current_state = state_space.sample()\n", "next_state = current_state + direction\n", "current_state, direction, next_state" ], "id": "d36b3a0b" }, { "cell_type": "markdown", "metadata": {}, "source": [ "`-` Class 구현: 아래와 같은 느낌의 클래스를 구현해보자." ], "id": "41a8f9bf-5db6-490e-88d5-3986c56bafa3" }, { "cell_type": "code", "execution_count": 236, "metadata": { "tags": [] }, "outputs": [], "source": [ "class GridWorld:\n", " def __init__(self):\n", " self.state_space = gym.spaces.MultiDiscrete([4,4])\n", " self.action_space = gym.spaces.Discrete(4) \n", " self._action_to_direction = {\n", " 0 : np.array([1, 0]), # row+, down\n", " 1 : np.array([0, 1]), # col+, right\n", " 2 : np.array([-1 ,0]), # row-, up\n", " 3 : np.array([0, -1]) # col-, left\n", " }\n", " self.reset()\n", " self.state = None \n", " self.reward = None \n", " self.termiated = None\n", " def step(self,action):\n", " direction = self._action_to_direction[action]\n", " self.state = self.state + direction\n", " if np.array_equal(self.state,np.array([3,3])): \n", " self.reward = 100 \n", " self.terminated = True\n", " elif self.state not in self.state_space:\n", " self.reward = -10\n", " self.terminated = True\n", " else:\n", " self.reward = -1 \n", " return self.state, self.reward, self.terminated\n", " def reset(self):\n", " self.state = np.array([0,0])\n", " self.terminated = False \n", " return self.state " ], "id": "e89495dc" }, { "cell_type": "code", "execution_count": 42, "metadata": {}, "outputs": [], "source": [ "env = GridWorld()\n", "state = env.reset()\n", "states = [] \n", "states.append(state)\n", "for t in range(50):\n", " action = env.action_space.sample() \n", " state,reward,terminated = env.step(action)\n", " states.append(state)\n", " if terminated: break " ], "id": "a2358dfa" }, { "cell_type": "code", "execution_count": 43, "metadata": { "tags": [] }, "outputs": [], "source": [ "show(states)" ], "id": "886e5543-619a-4488-a39f-77bc8a2fa254" }, { "cell_type": "markdown", "metadata": {}, "source": [ "- 처음에 바로 죽는 경우가 많아 몇번 시도하고 위의 애니메이션을 얻음\n", "\n", "# 6. `AgentRandom`\n", "\n", "## A. 에이전트 클래스 설계\n", "\n", "`-` 우리가 구현하고 싶은 기능\n", "\n", "- `.act()`: 액션을 결정 –\\> 여기서는 그냥 랜덤액션\n", "- `.save_experience()`: 데이터를 저장 –\\> 여기에 일단 초점을 맞추자\n", "- `.learn()`: 데이터로에서 학습 –\\> 패스" ], "id": "4e5b3f92-2dd7-4dc2-81df-a8e94ae68ce7" }, { "cell_type": "code", "execution_count": 143, "metadata": {}, "outputs": [], "source": [ "class AgentRandom: \n", " def __init__(self,env):\n", " #--# define spaces \n", " self.action_space = env.action_space\n", " self.state_space = env.state_space\n", " #--# replay buffer \n", " self.action = None \n", " self.actions = [] \n", " self.current_state = None \n", " self.current_states = [] \n", " self.reward = None \n", " self.rewards = [] \n", " self.next_state = None \n", " self.next_states = [] \n", " self.terminated = None \n", " self.terminations = []\n", " #--# other information\n", " self.n_episodes = 0 \n", " self.n_experiences = 0\n", " self.score = 0 \n", " self.playtimes = [] \n", " self.scores = [] \n", " def act(self):\n", " self.action = self.action_space.sample()\n", " def learn(self):\n", " pass \n", " def save_experience(self):\n", " self.current_states.append(self.current_state) \n", " self.actions.append(self.action)\n", " self.rewards.append(self.reward) \n", " self.next_states.append(self.next_state)\n", " self.terminations.append(self.terminated)\n", " #--#\n", " self.n_experiences = self.n_experiences + 1 \n", " self.score = self.score + self.reward" ], "id": "b6a2a93a" }, { "cell_type": "markdown", "metadata": {}, "source": [ "## B. 환경과 상호작용" ], "id": "3927789c-0b99-4969-a79e-20a9472d752d" }, { "cell_type": "code", "execution_count": 140, "metadata": {}, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "에피소드: 1 점수(에피소드): -12 게임시간(에피소드): 3 경험수: 3\n", "에피소드: 2 점수(에피소드): -15 게임시간(에피소드): 6 경험수: 9\n", "에피소드: 3 점수(에피소드): -11 게임시간(에피소드): 2 경험수: 11\n", "에피소드: 4 점수(에피소드): -10 게임시간(에피소드): 1 경험수: 12\n", "에피소드: 5 점수(에피소드): -10 게임시간(에피소드): 1 경험수: 13\n", "에피소드: 6 점수(에피소드): -12 게임시간(에피소드): 3 경험수: 16\n", "에피소드: 7 점수(에피소드): -11 게임시간(에피소드): 2 경험수: 18\n", "에피소드: 8 점수(에피소드): -18 게임시간(에피소드): 9 경험수: 27\n", "에피소드: 9 점수(에피소드): -10 게임시간(에피소드): 1 경험수: 28\n", "에피소드: 10 점수(에피소드): 91 게임시간(에피소드): 10 경험수: 38\n", "에피소드: 11 점수(에피소드): -10 게임시간(에피소드): 1 경험수: 39\n", "에피소드: 12 점수(에피소드): -10 게임시간(에피소드): 1 경험수: 40\n", "에피소드: 13 점수(에피소드): -11 게임시간(에피소드): 2 경험수: 42\n", "에피소드: 14 점수(에피소드): -10 게임시간(에피소드): 1 경험수: 43\n", "에피소드: 15 점수(에피소드): -10 게임시간(에피소드): 1 경험수: 44\n", "에피소드: 16 점수(에피소드): -10 게임시간(에피소드): 1 경험수: 45\n", "에피소드: 17 점수(에피소드): -10 게임시간(에피소드): 1 경험수: 46\n", "에피소드: 18 점수(에피소드): -15 게임시간(에피소드): 6 경험수: 52\n", "에피소드: 19 점수(에피소드): -11 게임시간(에피소드): 2 경험수: 54\n", "에피소드: 20 점수(에피소드): -10 게임시간(에피소드): 1 경험수: 55\n", "에피소드: 21 점수(에피소드): -10 게임시간(에피소드): 1 경험수: 56\n", "에피소드: 22 점수(에피소드): -12 게임시간(에피소드): 3 경험수: 59\n", "에피소드: 23 점수(에피소드): -11 게임시간(에피소드): 2 경험수: 61\n", "에피소드: 24 점수(에피소드): -11 게임시간(에피소드): 2 경험수: 63\n", "에피소드: 25 점수(에피소드): -13 게임시간(에피소드): 4 경험수: 67\n", "에피소드: 26 점수(에피소드): -10 게임시간(에피소드): 1 경험수: 68\n", "에피소드: 27 점수(에피소드): -10 게임시간(에피소드): 1 경험수: 69\n", "에피소드: 28 점수(에피소드): -11 게임시간(에피소드): 2 경험수: 71\n", "에피소드: 29 점수(에피소드): -10 게임시간(에피소드): 1 경험수: 72\n", "에피소드: 30 점수(에피소드): -13 게임시간(에피소드): 4 경험수: 76\n", "에피소드: 31 점수(에피소드): -10 게임시간(에피소드): 1 경험수: 77\n", "에피소드: 32 점수(에피소드): -10 게임시간(에피소드): 1 경험수: 78\n", "에피소드: 33 점수(에피소드): -18 게임시간(에피소드): 9 경험수: 87\n", "에피소드: 34 점수(에피소드): -13 게임시간(에피소드): 4 경험수: 91\n", "에피소드: 35 점수(에피소드): -18 게임시간(에피소드): 9 경험수: 100\n", "에피소드: 36 점수(에피소드): -10 게임시간(에피소드): 1 경험수: 101\n", "에피소드: 37 점수(에피소드): -15 게임시간(에피소드): 6 경험수: 107\n", "에피소드: 38 점수(에피소드): -10 게임시간(에피소드): 1 경험수: 108\n", "에피소드: 39 점수(에피소드): -13 게임시간(에피소드): 4 경험수: 112\n", "에피소드: 40 점수(에피소드): -17 게임시간(에피소드): 8 경험수: 120\n", "에피소드: 41 점수(에피소드): -10 게임시간(에피소드): 1 경험수: 121\n", "에피소드: 42 점수(에피소드): -10 게임시간(에피소드): 1 경험수: 122\n", "에피소드: 43 점수(에피소드): -10 게임시간(에피소드): 1 경험수: 123\n", "에피소드: 44 점수(에피소드): -10 게임시간(에피소드): 1 경험수: 124\n", "에피소드: 45 점수(에피소드): -10 게임시간(에피소드): 1 경험수: 125\n", "에피소드: 46 점수(에피소드): -17 게임시간(에피소드): 8 경험수: 133\n", "에피소드: 47 점수(에피소드): -10 게임시간(에피소드): 1 경험수: 134\n", "에피소드: 48 점수(에피소드): -13 게임시간(에피소드): 4 경험수: 138\n", "에피소드: 49 점수(에피소드): -12 게임시간(에피소드): 3 경험수: 141\n", "에피소드: 50 점수(에피소드): -11 게임시간(에피소드): 2 경험수: 143" ] } ], "source": [ "env = GridWorld()\n", "agent = AgentRandom(env)\n", "#--#\n", "for _ in range(50):\n", " agent.current_state = env.reset()\n", " agent.score = 0 \n", " for t in range(100):\n", " # step1: 행동\n", " agent.act()\n", " # step2: 보상\n", " agent.next_state, agent.reward, agent.terminated = env.step(agent.action)\n", " # step3: 저장 & 학습\n", " agent.save_experience()\n", " agent.learn()\n", " # step4: \n", " agent.current_state = agent.next_state\n", " if agent.terminated: break\n", " agent.scores.append(agent.score) \n", " agent.playtimes.append(t+1)\n", " agent.n_episodes = agent.n_episodes + 1 \n", " #---#\n", " print(\n", " f\"에피소드: {agent.n_episodes} \\t\"\n", " f\"점수(에피소드): {agent.scores[-1]} \\t\" \n", " f\"게임시간(에피소드): {agent.playtimes[-1]}\\t\"\n", " f\"경험수: {agent.n_experiences}\"\n", " )" ], "id": "50846003" }, { "cell_type": "markdown", "metadata": {}, "source": [ "## C. 상호작용결과 시각화" ], "id": "50c4dc0f-df3c-4955-9b8f-897a67132b19" }, { "cell_type": "code", "execution_count": 141, "metadata": {}, "outputs": [], "source": [ "[np.array([0,0])] + agent.next_states[28:38] # 에피소드10" ], "id": "c5604215" }, { "cell_type": "code", "execution_count": 142, "metadata": {}, "outputs": [], "source": [ "show([np.array([0,0])] + agent.next_states[28:38]) # 에피소드5" ], "id": "9b695ca7" } ], "nbformat": 4, "nbformat_minor": 5, "metadata": { "kernelspec": { "name": "python3", "display_name": "Python 3 (ipykernel)", "language": "python" }, "language_info": { "name": "python", "codemirror_mode": { "name": "ipython", "version": "3" }, "file_extension": ".py", "mimetype": "text/x-python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.11.8" } } }