How I Built a Reverse Image Search with Machine Learning and TensorFlow: Part 3

Published Jun 07, 2017

Welcome Back Again...

I’ve been making some TensorFlow examples for my website, fomoro.com, and one of the ones I created was a lightweight reverse image search. While it’s fresh in my head, I wanted to write up an end-to-end description of what it’s like to build a machine learning app, and more specifically, how to make your own reverse image search. For this demo, the work is ⅓ data munging/setup, ⅓ model development and ⅓ app development.

At a high level, I use TensorFlow to create an autoencoder, train it on a bunch of images, use the trained model to find related images, and display them with a Flask app.

In the last post, I talked about model development and training. Now that I have a trained model, let’s talk about how to use it.

Read Part 1: Project Setup
Read Part 2: Model Development

Ready? Let’s get started…

Using The Model with Checkpoints

Using a model for inference is a lot like training a model: you just don’t tell it if it’s right or not. There are all sorts of ways you can use one in production. For high-performance inference, TensorFlow Serving is the way to go, but it’s not the easiest thing to get set up.

Since this is a demo, I’m going the easy route and simply using the final model checkpoint. Checkpoints are great because you can stop and restart training, so I use them a lot. Fortunately, checkpoint handling is built right into train_and_evaluate, part of the Experiment class, which is another point in favor of using them instead of the more common session.run(). You can configure how often checkpoints are saved with the RunConfig:

run_config = tf.contrib.learn.RunConfig(

experiment.py line:33


In order to use my existing checkpoints, I created a new file, predict.py. This is a purpose-built file and only really relevant to this specific use case, unlike the model and experiment files, which I include in almost every project.

The job of the predict script is to load a checkpointed model, run a bunch of images through that model, compare their embeddings, and finally save the results into a couple files I use directly in my app. Unlike a classifier problem where I would use the model directly, I ultimately need to know which images are related to which other images. Running them all through in a batch allows me to associate them together without having to also keep all the embeddings around.

In the code below, all I had to do was give my estimator the location of my model checkpoint directory, and it took care of all the heavy lifting. Boom. Done.

estimator = tf.contrib.learn.Estimator(

predictions_iter = estimator.predict(

predict.py line:53

From TensorFlow to Python

Now that I’ve got all my predictions, it’s time to drop out of TensorFlow land and back into regular Python. I turn the results of my model into a list, and I can now do normal Python things with the results, like iterate over them or get the total number of features in my encoded image.

# drop out of tensorflow into regular python/numpy
predictions_list = list(predictions_iter)
features_length = len(predictions_list[0]['encoded_image'].flatten())

predict.py line:62
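Here is a small sketch of why the list() call matters, assuming predict hands back a one-shot iterator (as the name predictions_iter suggests). The fake_predictions generator and the flatten helper are hypothetical stand-ins for the estimator output and NumPy's flatten(), so the snippet runs without TensorFlow:

```python
def fake_predictions():
    """Hypothetical stand-in for estimator.predict(): one dict per image."""
    for offset in range(3):
        # a tiny 2x2 "encoded image" written as nested lists
        yield {'encoded_image': [[offset, offset + 1], [offset + 2, offset + 3]]}


def flatten(nested):
    """Recursively flatten nested lists into one flat list of numbers."""
    flat = []
    for item in nested:
        if isinstance(item, list):
            flat.extend(flatten(item))
        else:
            flat.append(item)
    return flat


predictions_iter = fake_predictions()

# materialize the results once; after this the iterator is used up
predictions_list = list(predictions_iter)
features_length = len(flatten(predictions_list[0]['encoded_image']))

# a second pass over the same iterator yields nothing
leftover = list(predictions_iter)
```

Keeping the results in a list means I can index into them and loop over them as many times as I need.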

Nearest Neighbor Comparisons

I’m using a really handy library, Annoy, to do my nearest neighbor comparisons. It’s an approximate search, so it trades accuracy for speed, but that’s a compromise I like in this situation. The downside to that speed is that it indexes items by integer, so I also needed to build the filename associations myself. I also needed to know the exact length of the features that I’m saving ahead of time (and all the features have to be the same length). But in the grand scheme of the code those are small potatoes, and I’m happy not to have to write a search from scratch.

While I’m giving shout-outs to good libraries, I’d also like to mention tqdm, which is a really easy way to add progress bars to the command line. In my case, it’s probably a bit of overkill, but it’s still really nice to see what my script is doing instead of staring at a blank prompt.

# build search and filename indexes
filenames = []
nn_search = AnnoyIndex(features_length)
for i in tqdm(range(len(predictions_list))):
    nn_search.add_item(i, embeddings_norm[i])

predict.py line:77
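To make the idea concrete, here is a rough sketch of the exact, brute-force search that Annoy approximates. It is not the Annoy API and not the code from predict.py: the filenames, the three-dimensional embeddings, and both helper functions are made up for illustration. It does show the two pieces described above, unit-length embeddings compared by cosine similarity and integer ids mapped back to filenames by list position:

```python
import math


def l2_normalize(vector):
    """Scale a vector to unit length so a dot product equals cosine similarity."""
    norm = math.sqrt(sum(x * x for x in vector))
    if norm == 0:
        return list(vector)
    return [x / norm for x in vector]


def nearest_neighbors(query_id, embeddings_norm, k):
    """Exact search: score every item against the query and keep the top k ids."""
    query = embeddings_norm[query_id]
    scored = []
    for item_id, candidate in enumerate(embeddings_norm):
        similarity = sum(q * c for q, c in zip(query, candidate))
        scored.append((similarity, item_id))
    # highest similarity first, ties broken by the lower id
    scored.sort(key=lambda pair: (-pair[0], pair[1]))
    return [item_id for _, item_id in scored[:k]]


# hypothetical data: the integer id of an image is its position in both lists
filenames = ['cat_1.jpg', 'cat_2.jpg', 'car_1.jpg']
embeddings = [[1.0, 0.1, 0.0], [0.9, 0.2, 0.0], [0.0, 0.1, 1.0]]
embeddings_norm = [l2_normalize(e) for e in embeddings]

neighbor_ids = nearest_neighbors(0, embeddings_norm, k=2)
neighbor_files = [filenames[i] for i in neighbor_ids]
```

This version compares the query against every image, which is fine for three items and painful for thousands. That cost is what an approximate index avoids.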

The last step is to build our search tree and save all the results into a couple of files to use in our app. We could save the index into a database, but for something this size, adding a database seems like more overkill. I’m also saving the features length as part of my metadata because I’ll need it when I use Annoy in my app. We could also pre-process all the neighbors, but then we’d have to choose between less flexibility in our final app and a flat file that’s overly large.

# build and save filename metadata
with open('{}/metadata.json'.format(args.out_dir), 'w') as outfile:
    json.dump({
        'timestamp': time.time(),
        'features_length': features_length,
        'filenames': filenames
    }, outfile)

# build and save search trees

predict.py line:81
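Reading the metadata back in the app is the mirror image of writing it. A minimal sketch, assuming a temporary directory in place of args.out_dir and made-up filenames:

```python
import json
import os
import tempfile
import time

out_dir = tempfile.mkdtemp()  # hypothetical stand-in for args.out_dir
filenames = ['cat_1.jpg', 'cat_2.jpg', 'car_1.jpg']  # hypothetical filenames
features_length = 3  # hypothetical embedding size

# write the metadata the same way the predict script does
metadata_path = os.path.join(out_dir, 'metadata.json')
with open(metadata_path, 'w') as outfile:
    json.dump({
        'timestamp': time.time(),
        'features_length': features_length,
        'filenames': filenames,
    }, outfile)

# later, in the app: load it back
with open(metadata_path) as infile:
    metadata = json.load(infile)

# an integer id from the search index maps to a filename by list position
image_for_id_2 = metadata['filenames'][2]
```

Because the filename list keeps its order through JSON, the position in the list is all the id-to-file mapping I need.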

App Setup

Now for the finishing bits and setting up a basic app. I created a new app directory with its own data directory, and that’s where I saved my Annoy and metadata files. My project and expanded app directory now look like this:

--------other files...
----other files…


Before I built out my app, I created a helper class to take care of dealing with Annoy and doing the lookups. It loads both of our files, sets a couple variables, and returns the nearest neighbors of a given valid index_id when asked. Since I always want to return something, I have it pick a random index_id if it gets one that it doesn’t recognise.

class AnnoyLookup(object):
    def __init__(self, 

        with open(metadata_path) as f:
            self._data = json.load(f)

        self._limit = len(self._data['filenames'])
        self._index = AnnoyIndex(self._data['features_length'])

    def get_neighbors(self, image_id, max_neighbors=13):
        results = []

        if image_id < 0 or image_id >= self._limit:
            image_id = random.randrange(self._limit)

        for item_id in self._index.get_nns_by_item(image_id, max_neighbors):
            results.append({
                'id': item_id,
                'image': self._data['filenames'][item_id]
            })

        return results

annoy_lookup.py line:15
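The fallback behavior is easy to try on its own. The sketch below is not the real AnnoyLookup: FakeIndex is a hypothetical stand-in that only mimics the shape of get_nns_by_item(item_id, n), and LookupSketch takes its filenames directly instead of loading files, so it runs without Annoy installed:

```python
import random


class FakeIndex(object):
    """Hypothetical stand-in for a search index with precomputed neighbors."""

    def __init__(self, neighbors):
        self._neighbors = neighbors

    def get_nns_by_item(self, item_id, n):
        # return at most n neighbor ids for the given item
        return self._neighbors[item_id][:n]


class LookupSketch(object):
    """Illustrative lookup helper; names and constructor are assumptions."""

    def __init__(self, filenames, index):
        self._filenames = filenames
        self._limit = len(filenames)
        self._index = index

    def get_neighbors(self, image_id, max_neighbors=13):
        # unknown ids fall back to a random valid one so we always return something
        if image_id < 0 or image_id >= self._limit:
            image_id = random.randrange(self._limit)

        results = []
        for item_id in self._index.get_nns_by_item(image_id, max_neighbors):
            results.append({'id': item_id, 'image': self._filenames[item_id]})
        return results


filenames = ['cat_1.jpg', 'cat_2.jpg', 'car_1.jpg']  # hypothetical filenames
index = FakeIndex({0: [0, 1, 2], 1: [1, 0, 2], 2: [2, 1, 0]})
lookup = LookupSketch(filenames, index)

known = lookup.get_neighbors(1, max_neighbors=2)
fallback = lookup.get_neighbors(-1, max_neighbors=2)  # -1 is not a valid id
```

Passing -1 on purpose is a cheap way to ask for a random starting image.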

Flask Setup

We’re finally ready to build our app. I love Flask for its simplicity. This was literally the easiest part of the project to write, which was nice because I hate saving the hard stuff for last: it almost always means that shipping gets delayed. Since I don’t want to move my images into my app directory from my data directory, I use Flask’s built-in send_from_directory function to return them from the server. It also lets me use any path I want to serve them from.

def index_route():
    results = lookup.get_multiple_neighbors(-1) # random starting image.
    return render_template('index.html', results=results)

@app.route('/nearest/<int:image_id>', methods=['GET'])
def get_nearest_html_route(image_id):
    results = lookup.get_multiple_neighbors(image_id)
    return render_template('index.html', results=results)

def get_data_route(path):
    return send_from_directory('../data/results/', path)

app.py line:14

The template is just basic HTML with Bootstrap thrown in so it only mostly looks like programmer art. Here’s a link to the source.

Extra Credit


Since the app was so fast to build, I ended up with some extra time on my hands. I used it to get the app up and running with a JavaScript front-end so I could host it on GitHub Pages. This required adding an API route to the server and offloading a little bit of display logic to some JavaScript.

@app.route('/api/nearest/<int:image_id>', methods=['GET'])
def get_nearest_api_route(image_id):
    results = lookup.get_multiple_neighbors(image_id)
    return jsonify(results=results)

app.py line:24

You can see the finished app at http://fomorians.github.io/imagesearch/.


So that’s it. From setup to model creation to app development. That’s how I used machine learning and TensorFlow to create a reverse image search. It was fun writing it all up. I hope you got something out of it too.

Questions? Comments? Let me know in the comments, or hit me up on Twitter: @jimmfleming
