Saturday 20 May 2017

How ConcurrentHashMap Works Internally in Java

Background

In one of the previous posts we saw how HashMap works -
and how its time complexity of insertion and deletion is O(1) in the normal case. Though it is a great data structure to work with in terms of time complexity, it is not thread safe, which means you cannot use it directly in multithreaded environments without taking additional precautions like synchronizing put/get on your own. Instead, Java provides a thread-safe implementation - ConcurrentHashMap - which we can use directly in multithreaded environments, e.g. with the parallel streams introduced in Java 8.
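A quick, minimal sketch of that usage (class and variable names here are illustrative, not from the original post): many threads write through a parallel stream, and ConcurrentHashMap keeps the map consistent without any external synchronization -

    import java.util.Map;
    import java.util.concurrent.ConcurrentHashMap;
    import java.util.stream.IntStream;

    public class ConcurrentHashMapDemo {
        public static void main(String[] args) {
            Map<Integer, Integer> map = new ConcurrentHashMap<>();

            // Many threads write at the same time via a parallel stream.
            // No external synchronization is needed for thread safety.
            IntStream.range(0, 10_000).parallel().forEach(i -> map.put(i, i * 2));

            System.out.println(map.size()); // 10000 - no lost updates
        }
    }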

How ConcurrentHashMap Works Internally in Java

Before we see how it is implemented in Java, let's give it some thought. What are the possible problems with a HashMap? Race conditions and invalid state. Let's say two writes happen at the same time. Since a write is not an atomic operation, one value may overwrite the other and the Map may end up in an inconsistent state. We could obviously add synchronization over all reads/writes of a HashMap, but that would be very inefficient and hurt performance - it would behave like a single-threaded application, which is certainly not what we want. To solve this issue Java provides ConcurrentHashMap, which has thread safety built in. Let's see how -
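A minimal sketch of the two options just discussed (names are illustrative): wrapping a HashMap behind one coarse lock versus using ConcurrentHashMap's built-in, finer-grained locking -

    import java.util.Collections;
    import java.util.HashMap;
    import java.util.Map;
    import java.util.concurrent.ConcurrentHashMap;

    public class CoarseVsFineLocking {
        public static void main(String[] args) {
            // One lock guards the whole map, so all threads serialize on every put/get.
            Map<String, Integer> coarse = Collections.synchronizedMap(new HashMap<>());

            // Locking is internal and per segment (pre Java 8), so writers that land
            // on different segments do not block each other; plain reads take no lock.
            Map<String, Integer> fine = new ConcurrentHashMap<>();

            coarse.put("a", 1);
            fine.put("a", 1);
        }
    }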

We know how HashMap works. Internally it stores an array of Entry objects, where each Entry essentially has a key, a value and a pointer to the next Entry object (a linked list used in case of collision). You can think of each array element as a bucket and each Entry object as a data point containing a key, a value and a pointer to the next data element (used when 2 keys map to the same bucket - a collision).
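A simplified sketch of that node layout (field names and order are indicative only, not the exact JDK source) -

    // Each bucket of the table points to a chain of nodes like this one.
    class Entry<K, V> {
        final K key;
        final int hash;     // cached hash of the key
        V value;
        Entry<K, V> next;   // next entry in the same bucket (collision chain)

        Entry(K key, int hash, V value, Entry<K, V> next) {
            this.key = key;
            this.hash = hash;
            this.value = value;
            this.next = next;
        }
    }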

Working :
ConcurrentHashMap, as the name suggests, allows concurrent reads/writes to the Map. But there are limitations. ConcurrentHashMap maintains another data structure internally called segments. Each bucket of the HashMap is part of one of the segments. The number of segments is called the concurrency level, and it determines the number of threads that can write simultaneously. A segment gets locked when data in it is written/updated/removed. Think of segments as locks used to prevent concurrent writes to the same bucket of the hashmap leading to inconsistency. So as long as writes to the ConcurrentHashMap are on different segments, they can happen in parallel. Reads are (almost) completely lock free, i.e. there is normally no need to acquire a lock for reading; the last updated value is returned.
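To make the segments-as-locks idea concrete, here is a tiny, hypothetical lock-striping sketch (this illustrates only the concept, it is not the JDK implementation) -

    import java.util.concurrent.locks.ReentrantLock;

    // One lock per "stripe"; a key's hash decides which stripe (and hence which lock)
    // guards it, so writes that hash to different stripes can proceed in parallel.
    class StripedLocks {
        private final ReentrantLock[] locks;

        StripedLocks(int concurrencyLevel) {
            locks = new ReentrantLock[concurrencyLevel];
            for (int i = 0; i < locks.length; i++) {
                locks[i] = new ReentrantLock();
            }
        }

        ReentrantLock lockFor(Object key) {
            return locks[(key.hashCode() & 0x7fffffff) % locks.length];
        }
    }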


 Now let's go step by step -

 Concurrency level, Segment array and initialization :
  • First, when you create a ConcurrentHashMap you can provide the concurrency level. This determines the size of the Segment array; the Segment array size will always be equal to or greater than the concurrency level. If it is not provided, the default of 16 (DEFAULT_CONCURRENCY_LEVEL) is used, and the level is capped at -
    • static final int MAX_SEGMENTS = 1 << 16; // slightly conservative
  • Note that the size of the Segment table will always be a power of 2. So if you give a concurrency level of 10, the next power of 2, i.e. 16, will be picked and a Segment array of size 16 will be created, which implies 16 threads can simultaneously write to the map.
static final class Segment<K,V> extends ReentrantLock implements Serializable {

    //The number of elements in this segment's region.
    transient volatile int count;
    //The per-segment table. 
    transient volatile HashEntry<K,V>[] table;
}

Putting element in ConcurrentHashMap :

  • For putting an element in the Map we first need to determine which segment it should go to. For this we first get the hashcode of the key. Next we rehash that hash to ensure a better spread of bits, as the comment on the hash() method below explains -
     /**
     * Applies a supplemental hash function to a given hashCode, which
     * defends against poor quality hash functions.  This is critical
     * because ConcurrentHashMap uses power-of-two length hash tables,
     * that otherwise encounter collisions for hashCodes that do not
     * differ in lower or upper bits.
     */
    private static int hash(int h) {
        // Spread bits to regularize both segment and index locations,
        // using variant of single-word Wang/Jenkins hash.
        h += (h <<  15) ^ 0xffffcd7d;
        h ^= (h >>> 10);
        h += (h <<   3);
        h ^= (h >>>  6);
        h += (h <<   2) + (h << 14);
        return h ^ (h >>> 16);
    }
  •  Once the hash is calculated you can get the segment it belongs to and delegate the put to that segment's put method as follows -
    public V put(K key, V value) {
        if (value == null)
            throw new NullPointerException();
        int hash = hash(key.hashCode());
        return segmentFor(hash).put(key, hash, value, false);
    }

    final Segment<K,V> segmentFor(int hash) {
        return segments[(hash >>> segmentShift) & segmentMask];
    }


We will see how the segment index is computed with a proper example in a bit. Once put is delegated to a segment, the segment adds the entry to the appropriate bucket within itself.

        V put(K key, int hash, V value, boolean onlyIfAbsent) {
            lock();
            try {
                int c = count;
                if (c++ > threshold) // ensure capacity
                    rehash();
                HashEntry<K,V>[] tab = table;
                int index = hash & (tab.length - 1);
                HashEntry<K,V> first = tab[index];
                HashEntry<K,V> e = first;
                while (e != null && (e.hash != hash || !key.equals(e.key)))
                    e = e.next;

                V oldValue;
                if (e != null) {
                    oldValue = e.value;
                    if (!onlyIfAbsent)
                        e.value = value;
                }
                else {
                    oldValue = null;
                    ++modCount;
                    tab[index] = new HashEntry<K,V>(key, hash, first, value);
                    count = c; // write-volatile
                }
                return oldValue;
            } finally {
                unlock();
            }
        }


Now this is a very interesting method. Let's understand what's happening here.

  • The first call is to lock(). Since this is a write/update operation on a bucket of the segment, we need a lock. If you recollect, the Segment class extends ReentrantLock, so each segment is a lock. Hence you can call lock() and unlock() directly on the Segment.
  • Next it's like a normal HashMap. You find the index of the Entry table where your element's hash falls and add it there as a linked list node.
  • You can see code similar to HashMap: it updates the value if the key is the same, inserts into the array slot if there is no element in that bucket, and prepends to the bucket's linked list if an element already exists.
  • Finally, once the operation is complete, it calls unlock() so that other threads can continue updating.
  • Note that lock() is a blocking call.
  • You can also see a call to rehash() if the threshold is reached. Like HashMap's Entry array, each Segment also has a threshold, and when it is reached the Segment's Entry table is resized for performance. That's what rehash() does (see the small sketch after this list).
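The resize itself is out of scope here, but the reason entries can move during rehash is that the bucket index is hash & (table.length - 1), so doubling the table brings one more hash bit into play. A tiny sketch with made-up numbers (not JDK code) -

    public class RehashSketch {
        public static void main(String[] args) {
            int hash = 0b111110;                 // hypothetical entry hash (62)
            int oldIndex = hash & (8 - 1);       // table of length 8  -> index 6
            int newIndex = hash & (16 - 1);      // table of length 16 -> index 14 (bit 3 is 1)
            System.out.println(oldIndex + " -> " + newIndex); // prints "6 -> 14"
        }
    }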
NOTE : For getting the index of the Segment table the first N bits of the enhanced hash are used, whereas for getting the index of the Entry table the last N bits are used (see details in the example below).

Getting element from  ConcurrentHashMap : 

A get on ConcurrentHashMap is very simple; normally no locks are involved. You simply read the data and return it -

        public V get(Object key) {
                int hash = hash(key.hashCode());
                return segmentFor(hash).get(key, hash);
        }

        V get(Object key, int hash) {
            if (count != 0) { // read-volatile
                HashEntry<K,V> e = getFirst(hash);
                while (e != null) {
                    if (e.hash == hash && key.equals(e.key)) {
                        V v = e.value;
                        if (v != null)
                            return v;
                        return readValueUnderLock(e); // recheck
                    }
                    e = e.next;
                }
            }
            return null;
        }


NOTE  : The readValueUnderLock method is used as a backup in case a null (not-yet-initialized) value is ever seen in an unsynchronized access method.
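For reference, in the pre-Java 8 source that fallback is roughly the following - it simply takes the segment's lock and re-reads the value -

        V readValueUnderLock(HashEntry<K,V> e) {
            lock();
            try {
                return e.value;
            } finally {
                unlock();
            }
        }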

Example

The above was all code and some explanation. Now let's take an actual example.

Let's say we have created a ConcurrentHashMap with a concurrency level of 10. Based on this, the Segment array will be created using the following code -

    private static void printSegmentDetails(int concurrencyLevel) {
        int sshift = 0;
        int segmentMask = 0;
        int segmentShift = 0;

        int ssize = 1;
        while (ssize < concurrencyLevel) {
            ++sshift;
            ssize <<= 1;
        }
        segmentShift = 32 - sshift;
        segmentMask = ssize - 1;
        System.out.println("Segment array size :" + ssize);
        System.out.println("segmentShift : " + segmentShift);
        System.out.println("segmentMask : " + segmentMask);
    }


Output for 10 concurrency level:
Segment array size : 16
segmentShift : 28
segmentMask : 15

NOTE : As mentioned before, the Segment array is of size 2^n such that 2^n >= concurrency level. In this case that is 2^4 = 16.

Now that we have the segment table in place, let's simulate a put. We will put a String "Aniket" as the key. We don't care about the value; just make sure it's not null.

  1. First we will calculate the hashcode of the key.
  2. Then rehash it to get a better hash (as mentioned above).
  3. Then, based on the resulting hash, we will find which segment it belongs to.
Remember the Segment table size is 2^N (>= concurrency level); we now want the first N bits of the hash to determine which segment this hash falls into, since N bits can represent 2^N values - exactly our segment array size. Also remember the code to get this index from above? -
  • int segmentIndex = (hash >>> segmentShift) & segmentMask
This essentially means logically right-shifting the hash by segmentShift bits. Since an int is 32 bits and segmentShift = 32 - sshift, hash >>> segmentShift essentially gives you the first sshift bits (sshift is nothing but the N in 2^N we saw above). segmentMask then keeps just those N bits after the shift.

So in this case,
N  = sshift =  4
2^N = 16 -> Size of segment array
segmentShift = 32 - 4 = 28 (as we saw in output above)
segmentMask = 16 - 1 = 15

    public static void main(String args[]) {    
        String key = "Aniket";
        //hascode of key
        System.out.println(key.hashCode());
        //better hash
        System.out.println(hash(key.hashCode()));
        //better hash in binary
        System.out.println(Integer.toBinaryString(hash(key.hashCode())));
        //logical right shift by segmentShift
        System.out.println("Right shifter hash : " + Integer.toBinaryString(hash(key.hashCode()) >>> 28));
        // segment index as binary and of right shift and segmentMask
        System.out.println("Segment Index : " + Integer.toBinaryString((hash(key.hashCode()) >>> 28 ) & 15));
        // segment index as decimal
        System.out.println("Segment Index : " + ((hash(key.hashCode()) >>> 28 ) & 15));
    }


Output :
1965716254
1839402854
1101101101000110000111101100110
Right shifted hash : 110
Segment Index : 110
Segment Index : 6


NOTE : 1101101101000110000111101100110 is only 31 bits because the leading (leftmost) bit is 0 and Integer.toBinaryString() does not print leading zeros. The same goes for all subsequent binary strings.

So your element with key "Aniket" will go into the segment at index 6 of the Segment array. Inside the segment it's pretty simple to calculate the index of the Entry array.

  •  int entryArrayindex = hash & (tab.length - 1);
         int entryArrayindex = (hash(key.hashCode()) & (16 - 1)); // assuming the segment's Entry table length is 16
         System.out.println("Entry array index : " + entryArrayindex);
         System.out.println("Entry array index in binary : " + Integer.toBinaryString(entryArrayindex));


Output :
Entry array index : 6
Entry array index in binary : 110


So finally the Entry is inserted at index 6 of the Entry table.

So to summarize: for getting the index of the Segment table, the first N bits of the enhanced hash are used, whereas for getting the index of the Entry table, the last N bits are used.


Related Links 

Left shift and right shift operators in Java

Left shift and right shift operators in Java

A lot of times we use the >>, >>> or << operators in Java for bit operations. These operations essentially shift bits right or left depending on the operator. In this post we will see what exactly the difference is.
  • >> is arithmetic or signed shift right
  • >>> is logical or unsigned shift right
  • << is arithmetic or signed shift left
The signed left shift operator "<<" shifts a bit pattern to the left, and the signed right shift operator ">>" shifts a bit pattern to the right. The bit pattern is given by the left-hand operand, and the number of positions to shift by the right-hand operand.

The unsigned right shift operator ">>>" shifts a zero into the leftmost position, while the leftmost position after ">>" depends on sign extension.

NOTE : There is no logical left shift operator, as it would be the same as the arithmetic left shift.

Example :

        System.out.println(Integer.toBinaryString(-121));
        // prints "11111111111111111111111110000111"
        System.out.println(Integer.toBinaryString(-121 >> 1));
        // prints "11111111111111111111111111000011"
        System.out.println(Integer.toBinaryString(-121 >>> 1));
        // prints "1111111111111111111111111000011"
        
        System.out.println(Integer.toBinaryString(121));
        // prints "1111001"
        System.out.println(Integer.toBinaryString(121 >> 1));
        // prints "111100"
        System.out.println(Integer.toBinaryString(121 >>> 1));
        // prints "111100"


As you can see, in the case of -121, since it is a negative number, the arithmetic or signed right shift fills in a 1 at the leftmost bit, whereas in the case of 121 it fills in a 0.

The logical or unsigned shift does not care about the sign. It just fills in 0s for the shifted-in bits.

