
A LunarLander Implementation Based on DQN and TensorFlow (Full Code)


We use a Deep Q-Network (DQN) to train a reinforcement learning agent in the OpenAI Gym LunarLander-v2 environment so that the little rocket lands successfully.

The code below can be dropped straight into a Jupyter notebook or Colab and run as-is.

Installing and Importing the Required Libraries and Environment

Install and set up the required libraries and environment so that everything runs inside a Jupyter Notebook.

!pip install gym
!apt-get install xvfb -y
!pip install pyvirtualdisplay  # creates a virtual display on machines without a monitor
!pip install Pillow            # image-processing library
!pip install swig
!pip install "gym[box2d]"

Create and start a virtual display so the reinforcement learning environment can run on a headless server:

from pyvirtualdisplay import Display
display = Display(visible=0, size=(1400, 900))
display.start()

Import the required libraries:

import gym
import time
import tqdm
import numpy as np
from IPython import display as ipydisplay
from PIL import Image

Create a DQN agent for the LunarLander-v2 environment (the DQNAgent class itself is implemented in the sections below) and evaluate it:

agent = DQNAgent('LunarLander-v2')

total_score, records = agent.simulate(visualize=True)
print(f'Total score {total_score:.2f}')
record_list = []
for i in tqdm.tqdm(range(100)):
    total_score, _ = agent.simulate(visualize=False)
    record_list.append(total_score)

print(f'Average score over 100 episodes {np.mean(record_list):.2f}')
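If you want a purely greedy evaluation run, you can pass epsilon=0 explicitly so that choose_action never explores (see its docstring below); a minimal sketch:

# epsilon=0 disables random exploration, so the model alone picks every action
greedy_score, _ = agent.simulate(epsilon=0, visualize=False)
print(f'Greedy-policy score {greedy_score:.2f}')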

Building the Q-Network

import tensorflow as tf

L = tf.keras.layers

def create_network_model(input_shape: tuple,
                         action_space: int,
                         learning_rate=0.001) -> tf.keras.Sequential:
    model = tf.keras.Sequential([
        L.Dense(512, input_shape=input_shape, activation="relu"),
        L.Dense(256, activation="relu"),
        L.Dense(action_space)
    ])
    model.compile(loss="mse",
                  optimizer=tf.optimizers.Adam(learning_rate=learning_rate))
    return model
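As a quick sanity check (assuming LunarLander-v2's 8-dimensional observation vector and 4 discrete actions), you can build the network and confirm it outputs one Q-value per action:

# Sanity-check sketch: 8-dim state in, 4 Q-values out (LunarLander-v2 dimensions)
model = create_network_model((8,), 4)
dummy_state = np.zeros((1, 8), dtype=np.float32)
print(model.predict(dummy_state).shape)  # expected: (1, 4)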

Implementing Experience Replay

Experience replay is a technique commonly used in deep reinforcement learning, mainly to break correlations in the training data and reduce overfitting.

In reinforcement learning the agent interacts with the environment a large number of times during training. Experience replay lets the agent store these experiences and reuse them repeatedly in later training. This mechanism improves learning efficiency, reduces the temporal correlation between samples, and makes the training process more stable.

import random
import numpy as np
from collections import namedtuple

# namedtuple representing a single sample, for convenient storage and access
Experience = namedtuple('Experience', ('state', 'action', 'reward', 'next_state', 'done'))

class ReplayMemory:

    def __init__(self, max_size):
        self.max_size = max_size
        self.memory = []

    def append(self, state, action, reward, next_state, done):
        """Record a new sample"""
        sample = Experience(state, action, reward, next_state, done)
        self.memory.append(sample)
        # keep only the most recent self.max_size samples
        self.memory = self.memory[-self.max_size:]

    def sample(self, batch_size):
        """Draw a random sample of the given batch size"""
        samples = random.sample(self.memory, batch_size)
        batch = Experience(*zip(*samples))

        # convert the data to numpy arrays and return them
        states = np.array(batch.state)
        actions = np.array(batch.action)
        rewards = np.array(batch.reward)
        states_next = np.array(batch.next_state)
        dones = np.array(batch.done)

        return states, actions, rewards, states_next, dones

    def __len__(self):
        return len(self.memory)
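A minimal usage sketch of the buffer, filled with random transitions (the shapes are illustrative, matching LunarLander's 8-dimensional state and 4 actions):

# Fill the buffer with random transitions, then draw one training batch
memory = ReplayMemory(max_size=1000)
for _ in range(128):
    s, s_next = np.random.rand(8), np.random.rand(8)
    memory.append(s, np.random.randint(4), 0.0, s_next, False)

states, actions, rewards, states_next, dones = memory.sample(64)
print(states.shape, actions.shape, dones.shape)  # (64, 8) (64,) (64,)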

Implementing the DQNAgent

The DQNAgent class is the core implementation of the DQN algorithm. It contains the following key parts:

1. Initialization: set up the environment, the neural network model, and the experience replay buffer.
2. Action selection (choose_action): pick an action for the current state using an ε-greedy policy.
3. Experience replay (replay): randomly sample a mini-batch of experiences from memory and learn from it (the regression target it uses is sketched right after this list).
4. Training (train): run training over multiple episodes.
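For reference, the replay step in item 3 regresses the predicted Q-value of each sampled transition $(s_i, a_i, r_i, s'_i, d_i)$ toward the standard one-step DQN target, which is exactly what the replay() method below computes:

$$y_i = r_i + \gamma \,(1 - d_i)\,\max_{a'} Q(s'_i, a')$$

where $d_i$ is the done flag and $\gamma$ is the discount factor GAMMA.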

from IPython import display
from PIL import Image

# Hyperparameters
LEARNING_RATE = 0.001
GAMMA = 0.99
EPSILON_DECAY = 0.995
EPSILON_MIN = 0.01

class DQNAgent:
    def __init__(self, env_name):
        self.env = gym.make(env_name)
        self.observation_shape = self.env.observation_space.shape
        self.action_count = self.env.action_space.n
        self.model = create_network_model(self.observation_shape, self.action_count)
        self.memory = ReplayMemory(500000)
        self.epsilon = 1.0
        self.batch_size = 64

    def choose_action(self, state, epsilon=None):
        """
        Choose an action for the given state
        - epsilon == 0: always let the model choose the action
        - epsilon == 1: always choose a random action
        """
        if epsilon is None:
            epsilon = self.epsilon
        if np.random.rand() < epsilon:
            return np.random.randint(self.action_count)
        else:
            q_values = self.model.predict(np.expand_dims(state, axis=0))
            return np.argmax(q_values[0])

    def replay(self):
        """Learn from a batch of replayed experience"""

        # skip if the replay buffer holds fewer samples than one batch
        if len(self.memory) < self.batch_size:
            return

        states, actions, rewards, states_next, dones = self.memory.sample(self.batch_size)
        q_pred = self.model.predict(states)

        q_next = self.model.predict(states_next).max(axis=1)
        q_next = q_next * (1 - dones)
        q_update = rewards + GAMMA * q_next

        indices = np.arange(self.batch_size)
        q_pred[indices, actions] = q_update

        self.model.train_on_batch(states, q_pred)
    def simulate(self, epsilon=None, visualize=True):
        records = []
        state = self.env.reset()
        is_done = False
        total_score = 0
        total_step = 0
        while not is_done:
            action = self.choose_action(state, epsilon)
            state, reward, is_done, info = self.env.step(action)
            total_score += reward
            total_step += 1

            rgb_array = self.env.render(mode='rgb_array')
            records.append((rgb_array, action, reward, total_score))

            if visualize:
                display.clear_output(wait=True)
                img = Image.fromarray(rgb_array)
                # show the current frame in the notebook cell
                display.display(img)
                print(f'Action {action} Action reward {reward:.2f} | Total score {total_score:.2f} | Step {total_step}')

                time.sleep(0.01)
        self.env.close()
        return total_score, records

    def train(self, episode_count: int, log_dir: str):
        """
        Train for the given number of episodes and log key metrics of the
        training process to TensorBoard
        """
        # create a TensorBoard writer
        file_writer = tf.summary.create_file_writer(log_dir)
        file_writer.set_as_default()

        score_list = []
        best_avg_score = -np.inf

        for episode_index in range(episode_count):
            state = self.env.reset()
            score, step = 0, 0
            is_done = False
            while not is_done:
                # choose an action based on the current state
                action = self.choose_action(state)
                # execute it and store the transition in the replay buffer
                state_next, reward, is_done, info = self.env.step(action)
                self.memory.append(state, action, reward, state_next, is_done)
                score += reward

                state = state_next
                # run a replay training step every 6 steps
                # replaying on every single step also works, but it slows training
                # down; this interval is an empirical trick
                if step % 6 == 0:
                    self.replay()
                step += 1

            # record this episode's score and compute the average over the last 100 episodes
            score_list.append(score)
            avg_score = np.mean(score_list[-100:])

            # log the episode score, epsilon, and the last-100-episode average to TensorBoard
            tf.summary.scalar('score', data=score, step=episode_index)
            tf.summary.scalar('average score', data=avg_score, step=episode_index)
            tf.summary.scalar('epsilon', data=self.epsilon, step=episode_index)

            # print training progress
            print(f'Episode: {episode_index} Reward: {score:03.2f} '
                  f'Average Reward: {avg_score:03.2f} Epsilon: {self.epsilon:.3f}')

            # decay epsilon to gradually reduce the amount of random exploration
            if self.epsilon > EPSILON_MIN:
                self.epsilon *= EPSILON_DECAY

            # save the model whenever the running average improves
            # make sure the directory outputs/chapter_15 exists beforehand
            if avg_score > best_avg_score:
                best_avg_score = avg_score
                self.model.save(f'outputs/chapter_15/dqn_best_{episode_index}.h5')

Training

# Initialize the agent with LunarLander
agent = DQNAgent('LunarLander-v2')

import glob
# count the runs already logged so that a new run gets its own log directory
tf_log_index = len(glob.glob('tf_dir/lunar_lander/run_*'))
log_dir = f'tf_dir/lunar_lander/run_{tf_log_index}'

# Train for 2000 episodes
agent.train(2000, log_dir)
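While train() is running you can watch the logged metrics; a sketch using the standard TensorBoard notebook extension (assuming it is available in your Jupyter/Colab environment):

# Load the TensorBoard notebook extension and point it at the log directory
%load_ext tensorboard
%tensorboard --logdir tf_dir/lunar_lander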

agent.model.summary()
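Since train() saves an .h5 checkpoint whenever the running average improves, you can later restore the best model into a fresh agent; a hypothetical reload sketch (the glob pattern simply matches the paths used in train() above):

import glob, os

# load the most recently written checkpoint back into the agent
checkpoints = glob.glob('outputs/chapter_15/dqn_best_*.h5')
if checkpoints:
    latest = max(checkpoints, key=os.path.getmtime)
    agent.model = tf.keras.models.load_model(latest)
    total_score, _ = agent.simulate(epsilon=0, visualize=False)
    print(f'Score with restored model {total_score:.2f}')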

 
