
How and why I built Dilated UNet with Centerline-Sampling

Published May 25, 2018

About me

I am a Master's student in Deep Learning at the Technical University of Munich and have been working in the field for the past 4 years.
Currently, I also work as a Computer Vision Engineer in medical imaging, applying deep learning to medical images and computer vision datasets.
My work mainly focuses on semantic segmentation and 3D-to-2D registration.
Previously, I worked as an NLP Engineer at Edge Networks, Bangalore, India, where my work primarily focused on applying deep learning to tasks such as named entity recognition and sentence classification using LSTMs.

The problem I wanted to solve

The problem at hand was the segmentation of small lung airways in a large lung MRI volume, which poses a severe class imbalance problem. Simply applying a standard UNet would not suffice for accurate semantic segmentation.
Moreover, because the image is large and three-dimensional, it does not fit into GPU memory as a whole, and working on 2D slices is suboptimal.

What is Dilated UNet with Centerline-Sampling?

I built an entire pipeline that takes the image, creates relevant 3D patches from it, and runs a specialised UNet that enables the system to segment the small lung airways.
This was achieved by implementing the following -

  1. Centerline Sampling - Random sampling would return mostly background samples, and the network would not learn anything meaningful. Sampling specifically around the trachea and the airways circumvents this problem. Moreover, to fit a batch into GPU memory, the patches were chosen to be 64x64x64.
  2. Dilated Convolutions - Help capture a wider context of the image and hence reduce the adverse effects of max pooling, namely downsampling and the accompanying loss of information.
  3. Weighted Dice Coefficient - Gives more weight to the foreground voxels compared to the background voxels. (Both of these are sketched in code right after this list.)
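
To make items 2 and 3 concrete, here is a minimal Keras sketch of a dilated convolution block and a weighted Dice loss. The filter counts, dilation rates, and the foreground weight are illustrative assumptions, not the exact values used in the project.

```python
# Illustrative sketch (assumed hyperparameters): a dilated convolution block
# and a weighted Dice loss for a strongly imbalanced foreground class.
from keras import backend as K
from keras.layers import Activation, BatchNormalization, Conv3D

def dilated_conv_block(x, filters, dilation_rates=(1, 2, 4)):
    """Stack of 3D convolutions with growing dilation rates.

    The growing receptive field captures more context without the
    resolution loss that max pooling introduces.
    """
    for rate in dilation_rates:
        x = Conv3D(filters, kernel_size=3, padding='same',
                   dilation_rate=rate)(x)
        x = BatchNormalization()(x)
        x = Activation('relu')(x)
    return x

def weighted_dice_loss(foreground_weight=10.0, smooth=1.0):
    """Dice loss that up-weights foreground (airway) voxels.

    `foreground_weight` is a hypothetical value; in practice it would be
    tuned to the actual foreground/background ratio of the dataset.
    """
    def loss(y_true, y_pred):
        w = 1.0 + (foreground_weight - 1.0) * y_true  # per-voxel weight map
        intersection = K.sum(w * y_true * y_pred)
        denom = K.sum(w * y_true) + K.sum(w * y_pred)
        dice = (2.0 * intersection + smooth) / (denom + smooth)
        return 1.0 - dice
    return loss
```

The weight map simply scales the foreground terms of the Dice formula, so the heavily outnumbered airway voxels still dominate the loss.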

Tech stack

Technologies Used-

  1. Keras - For building the underlying UNet. Keras was used because of its intuitive API and because it is now fully integrated into TensorFlow 1.7.
  2. TensorFlow - For building the data pipeline. It has built-in functions and classes that help build efficient data pipelines.
  3. NumPy - Basic matrix manipulations.
  4. h5py - To read the HDF5 files containing the MRI.
  5. scipy - Used for the centerline sampling. It has a handy method for centerline extraction, which makes the subsequent sampling easy.
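
As an illustration of how these pieces fit together, below is a minimal patch-loading pipeline built on h5py and tf.data. The file layout (datasets named 'image' and 'label'), the patch size, and the batch size are assumptions for the sketch.

```python
# Illustrative sketch: feeding 64x64x64 patches from an HDF5 volume into the
# network via tf.data. File name and dataset keys are assumed, not the real ones.
import h5py
import numpy as np
import tensorflow as tf

PATCH = 64  # patch edge length used throughout the post

def patch_generator(h5_path, centers):
    """Yield (image, label) patches centred on the given voxel coordinates."""
    with h5py.File(h5_path, 'r') as f:
        image, label = f['image'], f['label']  # assumed dataset names
        half = PATCH // 2
        for z, y, x in centers:
            # centres are assumed to lie at least half a patch from the border
            sl = np.s_[z - half:z + half, y - half:y + half, x - half:x + half]
            yield (image[sl][..., np.newaxis].astype(np.float32),
                   label[sl][..., np.newaxis].astype(np.float32))

def make_dataset(h5_path, centers, batch_size=4):
    shape = (PATCH, PATCH, PATCH, 1)
    ds = tf.data.Dataset.from_generator(
        lambda: patch_generator(h5_path, centers),
        output_types=(tf.float32, tf.float32),
        output_shapes=(shape, shape))
    return ds.shuffle(64).batch(batch_size).prefetch(1)
```

In TensorFlow 1.x the batches would be pulled through ds.make_one_shot_iterator(); later tf.keras versions can consume the dataset directly in model.fit.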

The process of building Dilated UNet with Centerline-Sampling

The standard procedure for any basic semantic segmentation project is to apply a baseline algorithm and evaluate.
In my case, the baseline algorithm in question was the UNet.
However, over multiple evaluations, small tweaks, and variations, the performance of the network remained abysmal. This led me to read relevant papers and projects, implement relevant ideas from each of them, and conduct ablation studies.
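
For reference, a baseline of this kind might look roughly like the sketch below; the depth and filter counts are placeholders rather than the exact baseline used here.

```python
# Illustrative sketch of a small baseline 3D UNet in Keras; depth and filter
# counts are assumptions, not the exact baseline from the project.
from keras.layers import (Conv3D, Input, MaxPooling3D, UpSampling3D,
                          concatenate)
from keras.models import Model

def conv_block(x, filters):
    x = Conv3D(filters, 3, padding='same', activation='relu')(x)
    return Conv3D(filters, 3, padding='same', activation='relu')(x)

def build_baseline_unet(patch_size=64):
    inputs = Input((patch_size, patch_size, patch_size, 1))

    # contracting path
    c1 = conv_block(inputs, 16)
    p1 = MaxPooling3D(2)(c1)
    c2 = conv_block(p1, 32)
    p2 = MaxPooling3D(2)(c2)

    c3 = conv_block(p2, 64)  # bottleneck

    # expanding path with skip connections
    u2 = concatenate([UpSampling3D(2)(c3), c2])
    c4 = conv_block(u2, 32)
    u1 = concatenate([UpSampling3D(2)(c4), c1])
    c5 = conv_block(u1, 16)

    outputs = Conv3D(1, 1, activation='sigmoid')(c5)
    return Model(inputs, outputs)
```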

Challenges I faced

Working with huge 3D datasets is genuinely hard.
Before switching to centerline sampling, we used random sampling, and the performance was not improving even after applying several tweaks and tricks.
The main problem with deep learning today is debugging and understanding what exactly is going on underneath, and random sampling makes that even tougher.
Moreover, max pooling reduces the resolution and results in a loss of information, yet the network needs a large receptive field to learn anything useful. Dilated convolutions provide that context without the resolution loss, which is why they were required.
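
To make the sampling difference concrete, the sketch below picks patch centres on the airway centerline rather than at random. It uses scikit-image's skeletonize_3d purely as a stand-in for the centerline-extraction step, and the mask name and patch size are assumptions.

```python
# Illustrative sketch of centerline-based patch sampling. skeletonize_3d from
# scikit-image stands in here for the centerline-extraction step.
import numpy as np
from skimage.morphology import skeletonize_3d

def sample_centers(airway_mask, n_samples, patch_size=64, seed=0):
    """Pick patch centres on the airway centerline instead of at random.

    Sampling along the skeleton guarantees that every patch contains airway
    voxels, avoiding the mostly-background patches that random sampling yields.
    """
    rng = np.random.RandomState(seed)
    skeleton = skeletonize_3d(airway_mask.astype(np.uint8))
    coords = np.argwhere(skeleton > 0)

    # keep centres far enough from the volume border to cut a full patch
    half = patch_size // 2
    upper = np.array(airway_mask.shape) - half
    valid = coords[np.all((coords >= half) & (coords < upper), axis=1)]

    idx = rng.choice(len(valid), size=n_samples, replace=True)
    return valid[idx]
```

The resulting centres could then feed the patch generator shown in the tech stack section above.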

Key learnings

I primarily learned the importance of testing each and every addition made to the network. Over the course of tweaking, I often forgot which feature or tweak had made a difference.
If I had to start again, I would keep a log of all the features and the corresponding difference each one made to the final performance.

Tips and advice

  1. Always start with a baseline and only then start tweaking your architecture or trying an entirely new one.
  2. Data is king. Observe what kind of data you are feeding into the network and come up with strategies that provide the best possible data to it.
  3. Always test your implementation on various datasets.

Final thoughts and next steps

I was able to successfully deliver the project to the client. One thing I observed was that the ground-truth labels were not as accurate as previously thought. So one direction I would like to work on is a Variational UNet, which does not strictly require highly accurate labels.
