trfl

TensorFlow Reinforcement Learning

  • 所有者: google-deepmind/trfl
  • 平台:
  • 許可證: Apache License 2.0
  • 分類:
  • 主題:
  • 喜歡:
    0
      比較:

Github星跟蹤圖

TRFL

TRFL (pronounced "truffle") is a library built on top of TensorFlow that exposes
several useful building blocks for implementing Reinforcement Learning agents.

Installation

TRFL can be installed from pip with the following command:
pip install trfl

TRFL will work with both the CPU and GPU version of tensorflow, but to allow
for that it does not list Tensorflow as a requirement, so you need to install
Tensorflow and Tensorflow-probability separately if you haven't already done so.

Usage Example

import tensorflow as tf
import trfl

# Q-values for the previous and next timesteps, shape [batch_size, num_actions].
q_tm1 = tf.get_variable(
    "q_tm1", initializer=[[1., 1., 0.], [1., 2., 0.]], dtype=tf.float32)
q_t = tf.get_variable(
    "q_t", initializer=[[0., 1., 0.], [1., 2., 0.]], dtype=tf.float32)

# Action indices, discounts and rewards, shape [batch_size].
a_tm1 = tf.constant([0, 1], dtype=tf.int32)
r_t = tf.constant([1, 1], dtype=tf.float32)
pcont_t = tf.constant([0, 1], dtype=tf.float32)  # the discount factor

# Q-learning loss, and auxiliary data.
loss, q_learning = trfl.qlearning(q_tm1, a_tm1, r_t, pcont_t, q_t)

loss is the tensor representing the loss. For Q-learning, it is half the
squared difference between the predicted Q-values and the TD targets, shape
[batch_size]. Extra information is in the q_learning namedtuple, including
q_learning.td_error and q_learning.target.

The loss tensor can be differentiated to derive the corresponding RL update.

reduced_loss = tf.reduce_mean(loss)
optimizer = tf.train.AdamOptimizer(learning_rate=0.1)
train_op = optimizer.minimize(reduced_loss)

All loss functions in the package return both a loss tensor and a namedtuple
with extra information, using the above convention, but different functions
may have different extra fields. Check the documentation of each function
below for more information.

Documentation

Check out the full documentation page
here.

主要指標

概覽
名稱與所有者google-deepmind/trfl
主編程語言Python
編程語言Python (語言數: 1)
平台
許可證Apache License 2.0
所有者活动
創建於2018-08-08 14:44:11
推送於2022-12-08 18:07:05
最后一次提交2021-08-12 17:05:26
發布數0
用户参与
星數3.1k
關注者數201
派生數387
提交數123
已啟用問題?
問題數20
打開的問題數4
拉請求數0
打開的拉請求數2
關閉的拉請求數7
项目设置
已啟用Wiki?
已存檔?
是復刻?
已鎖定?
是鏡像?
是私有?