import gymnasium as gym
#---#
import numpy as np
import collections
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation
import IPython14wk-1: (강화학습) – 4x4 Grid World 환경의 이해
1. 강의영상
2. Imports
3. 지난시간 코드 복습
- 클래스선언
- 수정사항: (1) deque의 maxlen을 500000 으로 조정 (2) print하는 코드를 주석처리
class GridWorld:
def __init__(self):
self.a2d = {
0: np.array([0,1]), # →
1: np.array([0,-1]), # ←
2: np.array([1,0]), # ↓
3: np.array([-1,0]) # ↑
}
self.state_space = gym.spaces.MultiDiscrete([4,4])
self.state = np.array([0,0])
self.reward = None
self.terminated = False
def step(self,action):
self.state = self.state + self.a2d[action]
s1,s2 = self.state
if (s1==3) and (s2==3):
self.reward = 100
self.terminated = True
elif self.state in self.state_space:
self.reward = -1
self.terminated = False
else:
self.reward = -10
self.terminated = True
# print(
# f"action = {action}\t"
# f"state = {self.state - self.a2d[action]} -> {self.state}\t"
# f"reward = {self.reward}\t"
# f"termiated = {self.terminated}"
# )
return self.state, self.reward, self.terminated
def reset(self):
self.state = np.array([0,0])
self.terminated = False
return self.state
class RandomAgent:
def __init__(self):
self.state = np.array([0,0])
self.action = None
self.reward = None
self.next_state = None
self.terminated = None
#---#
self.states = collections.deque(maxlen=500000)
self.actions = collections.deque(maxlen=500000)
self.rewards = collections.deque(maxlen=500000)
self.next_states = collections.deque(maxlen=500000)
self.terminations = collections.deque(maxlen=500000)
#---#
self.action_space = gym.spaces.Discrete(4)
self.n_experience = 0
def act(self):
self.action = self.action_space.sample()
def save_experience(self):
self.states.append(self.state)
self.actions.append(self.action)
self.rewards.append(self.reward)
self.next_states.append(self.next_state)
self.terminations.append(self.terminated)
self.n_experience = self.n_experience + 1
def learn(self):
pass- 메인코드
- 수정사항: 가독성을 위해 에피소드가 진행되는 for문의 구조를 수정함 (특히 step4)
마음에 들지 않지만 꼭 외워야 하는것
env.step은 항상 next_state, reward, terminated, truncated, info 를 리턴한다. – 짐나지엄 라이브러리 규격때문env.reset은 환경을 초기화할 뿐만 아니라, state, info를 반환하는 기능도 있다. – 짐나지엄 라이브러리 규격때문player는 항상state와next_state를 구분해서 저장한다. (다른변수들은 그렇지 않음) 이는 강화학습이 MDP(마코프체인+행동+보상)구조를 따르게 때문에 생기는 고유한 특징이다. – 이론적이 이유
- 환경과 에이전트의 상호작용 이해를 위한 다이어그램:
4. GridWorld 환경의 이해
A. 데이터 축적
- 랜덤에이전트를 이용해 무작위로 100,000 에피소드를 진행해보자.
player = RandomAgent()
env = GridWorld()
scores = []
score = 0
#
for e in range(1,100000):
#---에피소드시작---#
while True:
# step1 -- 액션선택
player.act()
# step2 -- 환경반응
player.next_state, player.reward, player.terminated = env.step(player.action)
# step3 -- 경험기록 & 학습
player.save_experience()
player.learn()
# step4 --종료 조건 체크 & 후속 처리
if env.terminated:
score = score + player.reward
scores.append(score)
score = 0
player.state = env.reset()
break
else:
score = score + player.reward
player.state = player.next_state강의노트 수정 2025-06-12
노규호학생의 도움으로 예전강의의 오류를 발견하여 수정하였습니다.
# 수정전
...
if env.terminated:
...
else:
score = score + player.reward
scores.append(score)
player.state = player.next_state
# 수정후
if env.terminated:
...
else:
score = score + player.reward
# scores.append(score) ### <--- 여기를 주석처리해야함!!
player.state = player.next_state
:::
::: {#ef29fdf9-b027-4d88-8bd7-1e5c6d3b08d9 .cell tags='[]' execution_count=82}
``` {.python .cell-code}
player.n_experience325268
B. 첫번째 q_table
- 밴딧게임에서는 \(q(a)\) 를 정의했었음.
- \(q(0) = 1\)
- \(q(1) = 10\)
- 여기에서는 \(q(s_1,s_2,a)\)를 정의해야함!
직관적으로 아래의 그림이 떠오름 
그림에 대응하는 \(q(s_1,s_2,a)\)의 값은 아래와 같음
\(a=0 \Leftrightarrow \text{\tt action=right}\)
\[ \begin{bmatrix} q(0,0,0) & q(0,1,0) & q(0,2,0) & q(0,3,0) \\ q(1,0,0) & q(1,1,0) & q(1,2,0) & q(1,3,0) \\ q(2,0,0) & q(2,1,0) & q(2,2,0) & q(2,3,0) \\ q(3,0,0) & q(3,1,0) &q(3,2,0) & q(3,3,0) \\ \end{bmatrix} = \begin{bmatrix} -1 & -1 & -1 & -10 \\ -1 & -1 & -1 & -10 \\ -1 & -1 & -1 & -10 \\ -1 & -1 & 100 & \text{-} \\ \end{bmatrix} \]
\(a=1 \Leftrightarrow \text{\tt action=left}\)
\[ \begin{bmatrix} q(0,0,1) & q(0,1,1) & q(0,2,1) & q(0,3,1) \\ q(1,0,1) & q(1,1,1) & q(1,2,1) & q(1,3,1) \\ q(2,0,1) & q(2,1,1) & q(2,2,1) & q(2,3,1) \\ q(3,0,1) & q(3,1,1) &q(3,2,1) & q(3,3,1) \\ \end{bmatrix} = \begin{bmatrix} -10 & -1 & -1 & -1 \\ -10& -1 & -1 & -1 \\ -10 & -1 & -1 & -1 \\ -10 & -1 & -1 & \text{-} \\ \end{bmatrix} \]
\(a=2 \Leftrightarrow \text{\tt action=down}\)
\[ \begin{bmatrix} q(0,0,2) & q(0,1,2) & q(0,2,2) & q(0,3,2) \\ q(1,0,2) & q(1,1,2) & q(1,2,2) & q(1,3,2) \\ q(2,0,2) & q(2,1,2) & q(2,2,2) & q(2,3,2) \\ q(3,0,2) & q(3,1,2) &q(3,2,2) & q(3,3,2) \\ \end{bmatrix} = \begin{bmatrix} -1 & -1 & -1 & -1 \\ -1& -1 & -1 & -1 \\ -1 & -1 & -1 & 100\\ -10 & -10 & -10 & \text{-} \\ \end{bmatrix} \]
\(a=3 \Leftrightarrow \text{\tt action=up}\)
\[ \begin{bmatrix} q(0,0,3) & q(0,1,3) & q(0,2,3) & q(0,3,3) \\ q(1,0,3) & q(1,1,3) & q(1,2,3) & q(1,3,3) \\ q(2,0,3) & q(2,1,3) & q(2,2,3) & q(2,3,3) \\ q(3,0,3) & q(3,1,3) &q(3,2,3) & q(3,3,3) \\ \end{bmatrix} =\begin{bmatrix} -10 & -10 & -10 & -10\\ -1& -1 & -1 & -1 \\ -1 & -1 & -1 & -1 \\ -1 & -1 & -1 & \text{-} \\ \end{bmatrix} \]
- 데이터를 바탕으로 \(q(s_1,s_2,a)\)를 구해보자.
player.states[0], player.actions[0], player.rewards[0](array([0, 0]), 0, -1)
q_table = np.zeros((4,4,4))
count = np.zeros((4,4,4))memory = zip(player.states, player.actions, player.rewards)
for (s1,s2), a, r in memory:
q_table[s1,s2,a] = q_table[s1,s2,a] + r
count[s1,s2,a] = count[s1,s2,a] + 1 count[count==0] = 0.001 q_table = q_table / countq_table[...,0], q_table[...,1], q_table[...,2], q_table[...,3](array([[ -1., -1., -1., -10.],
[ -1., -1., -1., -10.],
[ -1., -1., -1., -10.],
[ -1., -1., 100., 0.]]),
array([[-10., -1., -1., -1.],
[-10., -1., -1., -1.],
[-10., -1., -1., -1.],
[-10., -1., -1., 0.]]),
array([[ -1., -1., -1., -1.],
[ -1., -1., -1., -1.],
[ -1., -1., -1., 100.],
[-10., -10., -10., 0.]]),
array([[-10., -10., -10., -10.],
[ -1., -1., -1., -1.],
[ -1., -1., -1., -1.],
[ -1., -1., -1., 0.]]))
- count를 사용하지 않는 방법은 없을까? – 테크닉
q_table = np.zeros((4,4,4))
memory = zip(player.states, player.actions, player.rewards)
for (s1,s2), a, r in memory:
qhat = q_table[s1,s2,a] # 내가 생각했던갓
q = r # 실제값
diff = q-qhat # 차이
q_table[s1,s2,a] = q_table[s1,s2,a] + 0.01*diff# updateq_table.round(2)array([[[ -1. , -10. , -1. , -10. ],
[ -1. , -1. , -1. , -10. ],
[ -1. , -1. , -1. , -10. ],
[-10. , -1. , -1. , -10. ]],
[[ -1. , -10. , -1. , -1. ],
[ -1. , -1. , -1. , -1. ],
[ -1. , -1. , -1. , -1. ],
[-10. , -1. , -1. , -1. ]],
[[ -1. , -10. , -1. , -1. ],
[ -1. , -1. , -1. , -1. ],
[ -1. , -1. , -1. , -1. ],
[-10. , -1. , 99.99, -1. ]],
[[ -1. , -10. , -10. , -1. ],
[ -1. , -1. , -10. , -1. ],
[ 99.99, -1. , -10. , -1. ],
[ 0. , 0. , 0. , 0. ]]])
C. 첫번째 q_table보다 나은 것?
- 첫번째 q_table을 알고있다고 가정하자.

- 정책시각화 (합리적인 행동)

- 이게 최선의 정책일까?