tensorflow-triplet-loss

Implementation of triplet loss in TensorFlow


Triplet loss in TensorFlow

Author: Olivier Moindrot

This repository contains a triplet loss implementation in TensorFlow with online triplet mining.
Please check the blog post for a full description.
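For reference, the loss being mined is the standard triplet loss, summarized here in our own notation rather than copied from the code. For an anchor a, a positive p (same label) and a negative n (different label), d is the distance between their embeddings (Euclidean, or squared Euclidean when squared=True) and margin is the margin argument passed to the loss functions. The triplet categories used in online mining follow from where d(a, n) falls:

```latex
\begin{aligned}
\mathcal{L}(a, p, n) &= \max\bigl(d(a, p) - d(a, n) + \mathrm{margin},\ 0\bigr) \\
\text{easy (zero loss):} \quad & d(a, p) + \mathrm{margin} < d(a, n) \\
\text{semi-hard:} \quad & d(a, p) < d(a, n) < d(a, p) + \mathrm{margin} \\
\text{hard:} \quad & d(a, n) < d(a, p)
\end{aligned}
```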

The code structure is adapted from code I wrote for CS230 in this repository at tensorflow/vision.
A set of tutorials for this code can be found here.

Requirements

We recommend using python3 and a virtual environment.
Use the built-in venv module (or virtualenv with python3):

python3 -m venv .env
source .env/bin/activate
pip install -r requirements_cpu.txt

If you are using a GPU, you will need tensorflow-gpu instead, so run:

pip install -r requirements_gpu.txt

Triplet loss

Figure: Triplet loss on two positive faces (Obama) and one negative face (Macron)

The interesting part, defining the triplet loss with triplet mining, can be found in model/triplet_loss.py.
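Online triplet mining starts from the matrix of distances between every pair of embeddings in a batch. The pure-Python sketch below illustrates that step with the same squared option the loss functions expose; it is not the code in model/triplet_loss.py (which works on TensorFlow tensors), and the name pairwise_distances is our own.

```python
import math
from typing import List


def pairwise_distances(embeddings: List[List[float]], squared: bool = False) -> List[List[float]]:
    """Return the batch_size x batch_size matrix of distances between embeddings.

    Uses Euclidean distance, or squared Euclidean distance when squared=True.
    """
    n = len(embeddings)
    dist = [[0.0] * n for _ in range(n)]
    for i in range(n):
        for j in range(i + 1, n):
            # Squared Euclidean distance between embedding i and embedding j
            d2 = sum((x - y) ** 2 for x, y in zip(embeddings[i], embeddings[j]))
            # The matrix is symmetric, so fill both entries at once
            dist[i][j] = dist[j][i] = d2 if squared else math.sqrt(d2)
    return dist
```

For example, pairwise_distances([[0.0, 0.0], [3.0, 4.0]]) gives [[0.0, 5.0], [5.0, 0.0]], and 25.0 off the diagonal with squared=True. A vectorized version would typically expand ||a - b||² = ||a||² - 2a·b + ||b||² so the whole matrix comes from a single matrix product.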

Everything is explained in the blog post.

To use the "batch all" version, you can do:

from model.triplet_loss import batch_all_triplet_loss

loss, fraction_positive = batch_all_triplet_loss(labels, embeddings, margin, squared=False)

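In this case, fraction_positive is the fraction of valid triplets whose loss is positive, i.e. the hard and semi-hard triplets, which makes it a useful quantity to plot in TensorBoard to track how many such triplets remain during training.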
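The "batch all" strategy can be sketched in pure Python as follows. This is an illustrative re-implementation under our own assumptions (plain lists instead of tensors, a hypothetical function name batch_all_triplet_loss_sketch), not the TensorFlow code in model/triplet_loss.py. It enumerates every valid triplet (anchor and positive are distinct with the same label, the negative has a different label), averages the loss over the triplets with a positive loss, and returns the fraction of valid triplets with positive loss, mirroring the two values returned by batch_all_triplet_loss.

```python
import math
from typing import List, Tuple


def batch_all_triplet_loss_sketch(labels: List[int], embeddings: List[List[float]],
                                  margin: float, squared: bool = False) -> Tuple[float, float]:
    """Return (mean loss over positive-loss triplets, fraction of valid triplets with positive loss)."""
    n = len(labels)

    def dist(i: int, j: int) -> float:
        d2 = sum((x - y) ** 2 for x, y in zip(embeddings[i], embeddings[j]))
        return d2 if squared else math.sqrt(d2)

    num_valid = 0
    positive_losses = []
    for a in range(n):
        for p in range(n):
            # Valid positive: a different sample with the same label as the anchor
            if p == a or labels[p] != labels[a]:
                continue
            for neg in range(n):
                # Valid negative: any sample with a different label
                if labels[neg] == labels[a]:
                    continue
                num_valid += 1
                loss = max(dist(a, p) - dist(a, neg) + margin, 0.0)
                if loss > 0.0:  # hard or semi-hard triplet
                    positive_losses.append(loss)

    mean_loss = sum(positive_losses) / len(positive_losses) if positive_losses else 0.0
    fraction_positive = len(positive_losses) / num_valid if num_valid else 0.0
    return mean_loss, fraction_positive
```

With labels [0, 0, 1], embeddings [[0.0, 0.0], [1.0, 0.0], [3.0, 0.0]] and margin=2.0, only one of the two valid triplets has a positive loss, so the sketch returns (1.0, 0.5). The triple loop is O(B³) in the batch size B; a tensor implementation would build the same validity mask with broadcasting instead.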

To use the "batch hard" version, you can do:

from model.triplet_loss import batch_hard_triplet_loss

loss = batch_hard_triplet_loss(labels, embeddings, margin, squared=False)
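The "batch hard" strategy keeps a single triplet per anchor: the hardest positive (largest anchor-positive distance) and the hardest negative (smallest anchor-negative distance) in the batch. The pure-Python sketch below illustrates the idea under our own assumptions; the function name is hypothetical, and anchors without any valid positive or negative are simply skipped here, which may differ from how the TensorFlow code in model/triplet_loss.py handles that edge case.

```python
import math
from typing import List


def batch_hard_triplet_loss_sketch(labels: List[int], embeddings: List[List[float]],
                                   margin: float, squared: bool = False) -> float:
    """Mean over anchors of max(hardest positive dist - hardest negative dist + margin, 0)."""
    n = len(labels)

    def dist(i: int, j: int) -> float:
        d2 = sum((x - y) ** 2 for x, y in zip(embeddings[i], embeddings[j]))
        return d2 if squared else math.sqrt(d2)

    losses = []
    for a in range(n):
        pos = [dist(a, p) for p in range(n) if p != a and labels[p] == labels[a]]
        neg = [dist(a, k) for k in range(n) if labels[k] != labels[a]]
        if not pos or not neg:
            continue  # simplification: no valid triplet for this anchor
        hardest_positive = max(pos)  # farthest sample with the same label
        hardest_negative = min(neg)  # closest sample with a different label
        losses.append(max(hardest_positive - hardest_negative + margin, 0.0))
    return sum(losses) / len(losses) if losses else 0.0
```

For labels [0, 0, 1, 1], embeddings [[0.0, 0.0], [1.0, 0.0], [3.0, 0.0], [5.0, 0.0]] and margin=1.0, only the third anchor has a positive loss (1.0), so the mean over the four anchors is 0.25. Compared with "batch all", this uses one triplet per anchor, so the cost of the loss grows with B² pairwise distances rather than B³ triplets.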

Training on MNIST

To run a new experiment called base_model, do:

python train.py --model_dir experiments/base_model

You will first need to create a configuration file like this one: params.json.
This json file specifies all the hyperparameters for the model.
All the weights and summaries will be saved in the model_dir.
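As a hypothetical illustration of reading such a configuration file (the load_params helper and the keys used below are our assumptions, not the repository's actual API or params.json contents), a flat JSON object of hyperparameters can be loaded into an attribute-style object with the standard library:

```python
import json
from types import SimpleNamespace


def load_params(json_path: str) -> SimpleNamespace:
    """Load a flat JSON object of hyperparameters so they read as params.<name>."""
    with open(json_path) as f:
        return SimpleNamespace(**json.load(f))
```

For example, if params.json lived in experiments/base_model (an assumed location) and contained a "margin" key, load_params("experiments/base_model/params.json").margin would return that value, ready to pass as the margin argument of the triplet loss functions.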

Once the model is trained, you can visualize the embeddings by running:

python visualize_embeddings.py --model_dir experiments/base_model

Then run TensorBoard on the experiment directory:

tensorboard --logdir experiments/base_model

Here is the result (link to gif):

Figure: Embeddings of the MNIST test images visualized with T-SNE (perplexity 25)

Test

To run all the tests, run this from the project directory:

pytest

To run a specific test:

pytest model/tests/test_triplet_loss.py
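As an illustration of the kind of check that is easy to write for a triplet loss (this is not one of the repository's tests; the function and test names are hypothetical), a self-contained pytest file could compare a single-triplet loss against hand-computed values:

```python
import math


def triplet_loss(anchor, positive, negative, margin):
    """Hinge triplet loss on one (anchor, positive, negative) triplet, Euclidean distance."""
    d_ap = math.sqrt(sum((x - y) ** 2 for x, y in zip(anchor, positive)))
    d_an = math.sqrt(sum((x - y) ** 2 for x, y in zip(anchor, negative)))
    return max(d_ap - d_an + margin, 0.0)


def test_easy_triplet_has_zero_loss():
    # The negative is farther from the anchor than the positive by more than the margin
    assert triplet_loss([0.0, 0.0], [1.0, 0.0], [5.0, 0.0], margin=1.0) == 0.0


def test_hard_triplet_has_positive_loss():
    # The negative (distance 1) is closer to the anchor than the positive (distance 5)
    assert triplet_loss([0.0, 0.0], [3.0, 4.0], [1.0, 0.0], margin=0.5) == 4.5
```

Saved under a name starting with test_, pytest discovers both functions automatically when run from the project directory.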

Resources

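Key metrics

Overview
Name and owner: omoindrot/tensorflow-triplet-loss
Primary language: Python
Languages: Python (1 language)
License: MIT License
Owner activity
Created: 2018-03-13 07:05:04
Last pushed: 2019-05-09 18:20:20
Last commit: 2019-01-28 11:40:41
Releases: 0
User engagement
Stars: 1.1k
Watchers: 37
Forks: 283
Commits: 40
Issues: 61
Open issues: 32
Pull requests: 2
Open pull requests: 0
Closed pull requests: 2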