trfl

TensorFlow Reinforcement Learning

  • Owner: google-deepmind/trfl
  • License: Apache License 2.0

TRFL

TRFL (pronounced "truffle") is a library built on top of TensorFlow that exposes
several useful building blocks for implementing Reinforcement Learning agents.

Installation

TRFL can be installed from pip with the following command:
pip install trfl

TRFL works with both the CPU and GPU versions of TensorFlow. To allow for
that, it does not list TensorFlow as a requirement, so you need to install
TensorFlow and TensorFlow Probability separately if you haven't already done so.
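
Because these dependencies are not pulled in automatically, it can help to
confirm they are importable before using TRFL. The following is a minimal
sketch using only the Python standard library; the helper name
missing_packages is hypothetical, and the import names tensorflow and
tensorflow_probability are assumed to be the ones the installed packages
expose.

```python
import importlib.util


def missing_packages(names):
    """Return the subset of top-level module names that cannot be found."""
    return [name for name in names if importlib.util.find_spec(name) is None]


# Import names assumed for TensorFlow and TensorFlow Probability.
required = ["tensorflow", "tensorflow_probability"]
missing = missing_packages(required)
if missing:
    print("Install these before using trfl:", ", ".join(missing))
else:
    print("TensorFlow and TensorFlow Probability are importable.")
```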

Usage Example

import tensorflow as tf
import trfl

# Q-values for the previous and next timesteps, shape [batch_size, num_actions].
q_tm1 = tf.get_variable(
    "q_tm1", initializer=[[1., 1., 0.], [1., 2., 0.]], dtype=tf.float32)
q_t = tf.get_variable(
    "q_t", initializer=[[0., 1., 0.], [1., 2., 0.]], dtype=tf.float32)

# Action indices, discounts and rewards, shape [batch_size].
a_tm1 = tf.constant([0, 1], dtype=tf.int32)
r_t = tf.constant([1, 1], dtype=tf.float32)
pcont_t = tf.constant([0, 1], dtype=tf.float32)  # the discount factor

# Q-learning loss, and auxiliary data.
loss, q_learning = trfl.qlearning(q_tm1, a_tm1, r_t, pcont_t, q_t)

loss is the tensor representing the loss. For Q-learning, it is half the
squared difference between the predicted Q-values and the TD targets, shape
[batch_size]. Extra information is in the q_learning namedtuple, including
q_learning.td_error and q_learning.target.
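
To make the description above concrete, the same quantities can be worked out
by hand. This is a NumPy sketch of the arithmetic only, not TRFL's
implementation: it assumes the TD target is the reward plus the discounted
maximum next-step Q-value, and that the TD error is the target minus the
Q-value of the action taken.

```python
import numpy as np

# Same toy values as the usage example, as plain arrays instead of tensors.
q_tm1 = np.array([[1., 1., 0.], [1., 2., 0.]])
q_t = np.array([[0., 1., 0.], [1., 2., 0.]])
a_tm1 = np.array([0, 1])
r_t = np.array([1., 1.])
pcont_t = np.array([0., 1.])

# TD target: reward plus discounted maximum next-step Q-value (assumed form).
target = r_t + pcont_t * q_t.max(axis=1)

# Q-value of the action actually taken at the previous timestep.
qa_tm1 = q_tm1[np.arange(len(a_tm1)), a_tm1]

# Half the squared difference, one value per batch element.
td_error = target - qa_tm1
loss = 0.5 * td_error ** 2

print("target:", target)      # values 1 and 3
print("td_error:", td_error)  # values 0 and 1
print("loss:", loss)          # values 0 and 0.5
```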

The loss tensor can be differentiated to derive the corresponding RL update.

reduced_loss = tf.reduce_mean(loss)
optimizer = tf.train.AdamOptimizer(learning_rate=0.1)
train_op = optimizer.minimize(reduced_loss)
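
As a rough illustration of what differentiating the loss yields, the sketch
below applies one hand-written gradient step to the Q-values in NumPy. It
makes two assumptions that are not taken from the snippet above: the TD target
is treated as a constant, and plain gradient descent is used in place of Adam
to keep the arithmetic readable. The function name sgd_step_on_q is
hypothetical.

```python
import numpy as np


def sgd_step_on_q(q_tm1, a_tm1, target, learning_rate=0.1):
    """One plain gradient-descent step on mean(0.5 * (target - q[a]) ** 2)."""
    batch = np.arange(len(a_tm1))
    td_error = target - q_tm1[batch, a_tm1]
    grad = np.zeros_like(q_tm1)
    # d/dq[a] of 0.5 * td_error ** 2 is -td_error (target held constant);
    # taking the mean over the batch divides by the batch size.
    grad[batch, a_tm1] = -td_error / len(a_tm1)
    return q_tm1 - learning_rate * grad


q_tm1 = np.array([[1., 1., 0.], [1., 2., 0.]])
a_tm1 = np.array([0, 1])
target = np.array([1., 3.])  # TD targets for the toy values above

updated = sgd_step_on_q(q_tm1, a_tm1, target)
print(updated)
```

Only the Q-value of the action taken in each row moves, and it moves toward
its target; entries for actions that were not taken are left unchanged.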

All loss functions in the package return both a loss tensor and a namedtuple
with extra information, using the above convention, but different functions
may have different extra fields. Check the documentation of each function
below for more information.
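
The return convention can be pictured with ordinary Python namedtuples. This
is an illustrative sketch only: LossOutput, ToyExtra, and toy_td_loss are
hypothetical names, not TRFL's own types or functions, and lists of floats
stand in for tensors.

```python
import collections

# Hypothetical containers that mimic the "loss plus extra fields" convention.
LossOutput = collections.namedtuple("LossOutput", ["loss", "extra"])
ToyExtra = collections.namedtuple("ToyExtra", ["target", "td_error"])


def toy_td_loss(prediction, target):
    """Return a per-element loss together with a namedtuple of extras."""
    td_error = [t - p for p, t in zip(prediction, target)]
    loss = [0.5 * e * e for e in td_error]
    return LossOutput(loss, ToyExtra(target, td_error))


# The result unpacks like a pair, and the extras are read by field name.
loss, extra = toy_td_loss([1.0, 2.0], [1.0, 3.0])
print(loss, extra.td_error)
```

A different loss would keep the same two-part return shape while supplying
its own set of extra fields.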

Documentation

Check out the full documentation page here.

Main metrics

Overview
Name With Owner: google-deepmind/trfl
Primary Language: Python
Program language: Python (Language Count: 1)
Platform
License: Apache License 2.0
Owner activity
Created At: 2018-08-08 14:44:11
Pushed At: 2022-12-08 18:07:05
Last Commit At: 2021-08-12 17:05:26
Release Count: 0
User engagement
Stargazers Count: 3.1k
Watchers Count: 201
Fork Count: 387
Commits Count: 123
Has Issues Enabled
Issues Count: 20
Issue Open Count: 4
Pull Requests Count: 0
Pull Requests Open Count: 2
Pull Requests Close Count: 7