Saturday, August 12, 2017

A Trie


This is a trie that uses a sentinel child node to denote the end of a word. This is more space efficient than having to flag each node as to whether it denotes the end of a word. To quickly find the number of prefix matches, each node stores its prefix count, which is the number of added words that pass through it.








/**
 * A character trie backed by a HashMap per node. A word ends where a node has a
 * sentinel child keyed by {@code (char) 0}, and each node counts how many words
 * were added through it (its prefix count).
 */
class Trie {
    /** The character this node represents. */
    char ch;
    /** How many times {@link #add(char)} was called on this node, i.e. its prefix count. */
    int count = 0;
    /** Children keyed by character; the key {@code (char) 0} is the end-of-word sentinel. */
    Map<Character, Trie> list = new HashMap<Character, Trie>();

    /** Key of the sentinel child that marks the end of a word. */
    private static final char END_OF_WORD = (char) 0;

    public Trie(char ch) {
        this.ch = ch;
    }

    /**
     * Returns the child for {@code ch}, creating it if needed, and bumps this node's count.
     *
     * <p>The count is kept on the current node rather than the child so that the sentinel
     * node is never counted and the increment happens in one place only.
     */
    public Trie add(char ch) {
        this.count++;
        Trie child = this.list.get(ch);
        if (child == null) {
            child = new Trie(ch);
            this.list.put(ch, child);
        }
        return child;
    }

    /** Returns this node's prefix count. */
    public int size() {
        return this.count;
    }

    private Trie findChar(char ch) {
        return this.list.get(ch);
    }

    /** Follows {@code path} from this node; returns the node reached, or null if the path leaves the trie. */
    private Trie descend(String path) {
        Trie cursor = this;
        for (int i = 0; i < path.length() && cursor != null; i++) {
            cursor = cursor.findChar(path.charAt(i));
        }
        return cursor;
    }

    /** Returns true only if {@code word} was added as a whole word (not merely as a prefix). */
    public boolean findWord(String word) {
        Trie end = descend(word);
        // reaching a node only proves a prefix; a complete word also has the sentinel child
        return end != null && end.list.get(END_OF_WORD) != null;
    }

    /** Returns the prefix count of the node reached by {@code prefix}, or 0 if there is none. */
    public int findPartial(String prefix) {
        Trie end = descend(prefix);
        return (end == null) ? 0 : end.size();
    }

    /** Inserts {@code s}, terminating it with the sentinel child. */
    public void add(String s) {
        Trie cursor = this;
        for (int i = 0; i < s.length(); i++) {
            cursor = cursor.add(s.charAt(i));
        }
        cursor.add(END_OF_WORD);
    }
}

We can reduce the space taken by the trie further by using an array instead of the map. Since we only need lower-case letters, we can use the character just before 'a' in ASCII (the backtick, `) as the sentinel. Each node's array then needs 27 slots: one for the sentinel and one for each of 'a' through 'z'.

Another space optimization comes from using a single 32-bit int to store both the character and the prefix count. Java's char type takes two bytes, but one byte is enough for our characters. Even then, a one-byte character plus a four-byte int count would still take 5 bytes per Trie node. We also don't need the 2-billion range of a 32-bit int to count the words sharing any English prefix.

The prefix count is highest at the root node, since every word passes through it. So the highest prefix count is the number of words in the dictionary, which is generally no more than 250,000. We can therefore safely use 24 bits for the count. Stored unsigned, 24 bits hold values up to 2^24 − 1 = 16,777,215, and even as a signed integer they would reach about 8 million.

So we can combine the character and the prefix count to a single word.

Is there anything else we could do? Yes: after reading all the words into our Trie, we could trim the list on each Trie node. This follows from the observation that we rarely use all the slots in a list, especially deeper in the trie, where fewer words branch out from each node. So we can find the last used index in each list and replace the list with a shorter one that ends at that index.

Doing all of these drops the size of the trie from ~ 228M to ~ 68M.

Here is an implementation.

I store a random word list on pastebin for testing. The code below fetches it, and also reads a local dictionary of lower-case words. If you use this code, make sure the dictionary you substitute contains only lower-case words, so some pre-processing might be necessary. In particular, you are likely to find hyphens (-) in some words, which you will need to remove.

Last but not least, the JVM memory statistics don't reveal the space saving, because garbage collection is not deterministic. Instead, I use sizeInBytes() to recursively calculate the memory footprint of the Trie.



  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
/**
 * A space-optimised trie for lower-case ASCII words.
 *
 * <p>Each node packs its character (most significant byte) and its prefix count (low 24 bits)
 * into a single {@code int}. The end of a word is marked by a sentinel child whose character is
 * {@code '`'}, the ASCII character just before {@code 'a'}, so a full child array has 27 slots.
 * After loading the dictionary, {@link #trim()} shrinks every child array to its last used slot.
 */
public class Trie {

    /** End-of-word sentinel: the ASCII character immediately before {@code 'a'}. */
    private static final char SENTINEL = '`';

    /** Child slots needed for the sentinel plus {@code 'a'..'z'}. */
    private static final int ALPHABET_SIZE = 'z' - SENTINEL + 1;

    /** Largest prefix count that fits in the low 24 bits of {@link #count}. */
    private static final int MAX_COUNT = 0x00FFFFFF;

    /** Accepts lines of the random word list that consist only of ASCII letters; compiled once. */
    private static final Pattern ALPHA = Pattern.compile("^[A-Za-z]+$");

    /** Storage strategy for a node's children, keyed by character. */
    interface GetAndPut {
        /** Stores {@code trie} as the child for {@code ch}. */
        public void put(Character ch, Trie trie);

        /** Returns the child for {@code ch}, or {@code null} if there is none. */
        public Trie get(Character ch);

        /** Returns the last occupied slot (array) or child count minus one (map); -1 when empty. */
        public int lastUsedIndex();

        /** Returns every non-null child, including the end-of-word sentinel. */
        public Trie[] children();

        /** Releases unused child slots, where the implementation has any. */
        public void trim();
    }

    /**
     * Children kept in a HashMap keyed by character. Static, so instances do not carry a
     * hidden reference to an enclosing Trie.
     */
    static class SuffixCharsWithMap implements GetAndPut {
        Map<Character, Trie> list = new HashMap<Character, Trie>();

        @Override
        public void put(Character ch, Trie node) {
            list.put(ch, node);
        }

        @Override
        public Trie get(Character ch) {
            return list.get(ch);
        }

        @Override
        public int lastUsedIndex() {
            return list.size() - 1;
        }

        @Override
        public Trie[] children() {
            return list.values().toArray(new Trie[list.size()]);
        }

        @Override
        public void trim() {
            // intentionally a no-op for the map-backed variant
        }
    }

    /**
     * Children kept in an array indexed by {@code ch - '`'}. Static, so instances do not carry
     * a hidden reference to an enclosing Trie.
     */
    static class SuffixCharsWithArray implements GetAndPut {
        public Trie[] list;

        public SuffixCharsWithArray() {
            list = new Trie[ALPHABET_SIZE];
        }

        /**
         * Stores {@code node} in the slot for {@code ch}, growing the array again if an earlier
         * {@link #trim()} cut that slot off.
         *
         * @throws IllegalArgumentException if {@code ch} is neither the sentinel nor 'a'..'z'
         */
        @Override
        public void put(Character ch, Trie node) {
            int index = ch - SENTINEL;
            if (index < 0 || index >= ALPHABET_SIZE) {
                throw new IllegalArgumentException("unsupported character '" + ch
                        + "': only '" + SENTINEL + "' and 'a'..'z' can be stored");
            }
            if (index >= list.length) {
                Trie[] grown = new Trie[index + 1];
                System.arraycopy(list, 0, grown, 0, list.length);
                list = grown;
            }
            list[index] = node;
        }

        @Override
        public Trie get(Character ch) {
            int index = ch - SENTINEL;
            // unsupported characters and slots removed by trim() simply have no child
            return (index >= 0 && index < list.length) ? list[index] : null;
        }

        @Override
        public int lastUsedIndex() {
            for (int i = list.length - 1; i >= 0; i--) {
                if (list[i] != null) {
                    return i;
                }
            }
            return -1;
        }

        @Override
        public void trim() {
            int newLength = lastUsedIndex() + 1;
            if (newLength < list.length) {
                Trie[] shrunk = new Trie[newLength];
                System.arraycopy(list, 0, shrunk, 0, newLength);
                list = shrunk;
            }
        }

        /**
         * Returns all non-null children. Sentinel nodes are included, as in the map variant,
         * so that trim() and sizeInBytes() also reach and shrink them.
         */
        @Override
        public Trie[] children() {
            List<Trie> nodes = new ArrayList<Trie>();
            for (Trie t : list) {
                if (t != null) {
                    nodes.add(t);
                }
            }
            return nodes.toArray(new Trie[nodes.size()]);
        }
    }

    // Character and prefix count packed into 32 bits: the most significant byte holds the
    // character, the low 3 bytes hold the prefix count (up to 16,777,215). The root's count
    // equals the number of words added, well below that for an English dictionary.
    private int count = 0;

    /** Returns the character stored in the most significant byte. */
    public char getChar() {
        // parenthesised: '>>' binds tighter than '&', so the unparenthesised form masked nothing
        return (char) ((count >>> 24) & 0xFF);
    }

    /** Stores the low 8 bits of {@code ch} in the most significant byte, keeping the count. */
    public void setChar(char ch) {
        count = ((ch & 0xFF) << 24) | (count & MAX_COUNT);
    }

    /** Returns the prefix count held in the low 24 bits. */
    public int getCount() {
        return count & MAX_COUNT;
    }

    /**
     * Increments the prefix count.
     *
     * @throws IllegalStateException if the count would overflow into the character byte
     */
    public void incCount() {
        if (getCount() == MAX_COUNT) {
            throw new IllegalStateException("prefix count overflow: more than " + MAX_COUNT + " words");
        }
        count++;
    }

    // Array-backed children; assign new SuffixCharsWithMap() instead to compare the map variant.
    GetAndPut suffixChars = new SuffixCharsWithArray();

    public Trie(char ch) {
        this.setChar(ch);
    }

    /**
     * Returns the child for {@code ch}, creating it if needed, and bumps this node's count.
     *
     * <p>The count goes on the current node rather than the child, so the sentinel node is never
     * counted and the increment happens in one place only.
     */
    public Trie add(char ch) {
        Trie node = this.suffixChars.get(ch);
        if (node == null) {
            node = new Trie(ch);
            this.suffixChars.put(ch, node);
        }
        this.incCount();
        return node;
    }

    /** Returns this node's prefix count (without the packed character byte). */
    public int size() {
        return getCount();
    }

    private Trie findChar(char ch) {
        return this.suffixChars.get(ch);
    }

    /** Returns true only if {@code word} was added as a whole word (not merely as a prefix). */
    public boolean findWord(String word) {
        Trie node = this;
        for (char ch : word.toCharArray()) {
            node = node.findChar(ch);
            if (node == null) {
                return false;
            }
        }
        // reaching a node only proves a prefix; a complete word also has the sentinel child
        return node.suffixChars.get(SENTINEL) != null;
    }

    /** Returns how many added words start with {@code prefix}; 0 if none. */
    public int findPartial(String prefix) {
        Trie node = this;
        for (char ch : prefix.toCharArray()) {
            node = node.findChar(ch);
            if (node == null) {
                return 0;
            }
        }
        return node.size();
    }

    /** Inserts {@code s}, terminating it with the sentinel child. */
    public void add(String s) {
        Trie node = this;
        for (char ch : s.toCharArray()) {
            node = node.add(ch);
        }
        node.add(SENTINEL);
    }

    /** Tallies, per index, how many nodes have that as their last used child slot. */
    private void walk(int[] indices) {
        int last = this.suffixChars.lastUsedIndex();
        // sentinel nodes (and an empty root) have no children, hence no last used index
        if (last >= 0) {
            indices[last]++;
        }
        for (Trie child : suffixChars.children()) {
            child.walk(indices);
        }
    }

    /** Returns a histogram of last-used child indices over all nodes that have children. */
    public int[] lastUsedIndices() {
        int[] indices = new int[ALPHABET_SIZE];
        walk(indices);
        return indices;
    }

    /** Trims this node's child storage, then recurses into every child. */
    private void walk2() {
        suffixChars.trim();
        for (Trie child : suffixChars.children()) {
            child.walk2();
        }
    }

    /**
     * Approximates the bytes used by {@code t} and its subtree: 4 for the packed count plus 8 per
     * child slot. Object headers and array headers are not included.
     */
    static private int walk3(Trie t) {
        if (t == null) {
            return 0;
        }
        int slots = (t.suffixChars instanceof SuffixCharsWithArray)
                ? ((SuffixCharsWithArray) t.suffixChars).list.length
                : t.suffixChars.children().length;
        int acc = 4 + 8 * slots;
        for (Trie node : t.suffixChars.children()) {
            acc += walk3(node);
        }
        return acc;
    }

    /** Shrinks every node's child storage to its last used slot. Words can still be added later. */
    public void trim() {
        walk2();
    }

    /** Returns the approximate footprint of the trie as computed by {@code walk3}. */
    public int sizeInBytes() {
        return walk3(this);
    }

    /** Loads one word per line from the author's default dictionary path. */
    public void read() throws FileNotFoundException {
        read("/Users/thushara/lcwords.txt");
    }

    /**
     * Loads one lower-case word per line from {@code wordFilePath}, closing the file afterwards.
     *
     * @throws FileNotFoundException if the file cannot be opened
     */
    public void read(String wordFilePath) throws FileNotFoundException {
        try (BufferedReader br = new BufferedReader(new FileReader(wordFilePath))) {
            String word;
            while ((word = br.readLine()) != null) {
                add(word);
            }
        } catch (FileNotFoundException e) {
            throw e;
        } catch (IOException e) {
            System.err.format("disk error! %s%n", e.getMessage());
        }
    }

    /**
     * Downloads the test word list and returns its purely alphabetic lines, lower-cased and
     * separated by single spaces.
     */
    static public String getRandomWordList() throws MalformedURLException, IOException {
        URL url = new URL("https://pastebin.com/raw/NXH7UAr1");
        HttpURLConnection con = (HttpURLConnection) url.openConnection();
        con.setRequestMethod("GET");
        List<String> words = new ArrayList<String>();
        // getInputStream() throws IOException for HTTP error statuses
        try (BufferedReader in = new BufferedReader(new InputStreamReader(con.getInputStream()))) {
            String inputLine;
            while ((inputLine = in.readLine()) != null) {
                Matcher m = ALPHA.matcher(inputLine);
                if (m.matches()) {
                    // Locale.ROOT: in e.g. a Turkish locale 'I' would lower-case to a dotless i
                    words.add(inputLine.toLowerCase(java.util.Locale.ROOT));
                }
            }
        }
        // the separator matters: main() splits the result on " "
        return String.join(" ", words);
    }

    static public void main(String[] args) throws FileNotFoundException, IOException {
        long mem1 = Runtime.getRuntime().totalMemory() - Runtime.getRuntime().freeMemory();
        System.out.format("memory usage at start %d\n", mem1);
        Trie trie = new Trie('$');
        trie.read();
        System.out.format("size of trie in bytes: %d\n", trie.sizeInBytes());
        trie.trim();
        System.out.format("size of trimmed trie in bytes: %d\n", trie.sizeInBytes());

        Scanner in = new Scanner(System.in);
        System.out.println("type a word in lower case (upper case char to exit)> ");
        while (in.hasNext()) {
            String s = in.next();
            if (Character.isUpperCase(s.charAt(0))) {
                break;
            }
            boolean found = trie.findPartial(s) > 0;
            System.out.println(found ? "yes" : "no");
        }

        String words = getRandomWordList();

        // start the clock after the download so only the lookups are timed
        long st = System.currentTimeMillis();
        String[] arr = words.split(" ");
        for (String s : arr) {
            if (!s.isEmpty() && !trie.findWord(s)) {
                System.out.println("couldn't find " + s);
            }
        }
        long elapsed = System.currentTimeMillis() - st;
        long mem2 = Runtime.getRuntime().totalMemory() - Runtime.getRuntime().freeMemory();
        System.out.format("memory usage at end   %d\n", mem2);
        System.out.format("took %d ms for %d words using %d MB\n", elapsed, arr.length, (mem2 - mem1)/1024/1024);
    }

}

folding with python

Solving this problem the functional way =>

Find if a sorted list of positive numbers has duplicates. (In Python 3, reduce is no longer a builtin; run `from functools import reduce` first.)


>>> def has_dups(nums):
...     return reduce (lambda x,y: ( x[0] or (y == x[1]), y), nums, (False,0))[0]
... 
>>> has_dups([1])
False
>>> has_dups([1,1])
True
>>> has_dups([1,1,1])
True
>>> has_dups([1,2,4])
False
>>> has_dups([1,2,2])
True
>>> has_dups([1,2,2,3,4])
True

>>> 

From the definition of reduce:

reduce(function, iterable[, initializer])
Apply function of two arguments cumulatively to the items of iterable, from left to right, so as to reduce the iterable to a single value. For example, reduce(lambda x, y: x+y, [1, 2, 3, 4, 5]) calculates ((((1+2)+3)+4)+5). The left argument, x, is the accumulated value and the right argument, y, is the update value from the iterable. If the optional initializer is present, it is placed before the items of the iterable in the calculation, and serves as a default when the iterable is empty. If initializer is not given and iterable contains only one item, the first item is returned. Roughly equivalent to:

def reduce(function, iterable, initializer=None):
    it = iter(iterable)
    if initializer is None:
        try:
            initializer = next(it)
        except StopIteration:
            raise TypeError('reduce() of empty sequence with no initial value')
    accum_value = initializer
    for x in it:
        accum_value = function(accum_value, x)
    return accum_value

To solve the problem, we need to check whether two adjacent items have the same value. To do this the functional way, we walk through the list using the reduce operator. At each point in the list, the reduce operator applies the user-supplied function to the previous output (x) and the current item (y) of the list.

We need to remember whether we have seen two equal adjacent items, and pass that along as the output. But we also have to pass along the current item, so that at the next step the function can compare it with the next item. So the function returns a tuple:

(truth value: whether we have seen two equal adjacent items, current item)

We need to pass an initial value for the tuple (initializer). The truth value would be False initially, and we pass a zero as all elements of the list are positive. (False, 0)

So within the lambda, we need to check if the current item is the same as the previous item (y == x[1]) but if we had already met this condition (x[0]), we need to pass this along.

One of the drawbacks of the fold is that there is no quick break from traversing the list once we find a duplicate. It is possible to raise an exception in lambda and force a termination that way, but I don't know of a clean way to terminate the walk of the complete list.

Monday, August 07, 2017

Python 2.7 scoping bug

Here is a piece of code that does not work on Python 2.7:

 #!/usr/bin/python  
 def img_type(s):  # module-level function whose name the comprehension below reuses
   return str(s)  
 print (img_type(50))  # prints 50
 a = [img_dir for (img_dir, img_type) in [("a",1),("b",2)]]  # Python 2: loop targets img_dir/img_type are bound in module scope, rebinding img_type to 2
 print (img_type(20))  # Python 2: TypeError, img_type is now the int 2

On the second print, it raises a TypeError:

TypeError: 'int' object is not callable

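The interpreter binds the comprehension's loop variables img_dir and img_type in the enclosing scope, which here is the global (module) scope. This is documented Python 2 behavior: list comprehension loop variables leak into the surrounding scope. Since the function has the same name, the variable takes its place. Actually it overwrites the function!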

We can see what is happening by looking at the globals().items() and locals().items(). Each is a list of tuples where each tuple contains the variable name and its currently assigned value. Here is a modified program that lists the variables, before and after we define the list comprehension:

#!/usr/bin/python

def img_type(s):
    return str(s)

print (img_type(50))

print ("BEFORE")
print (globals().items())  # module-level names before the comprehension runs
print (locals().items())   # at module level this lists the same names as globals()


a = [img_dir for (img_dir, img_type) in [("a",1),("b",2)]]  # loop targets: leak into module scope in Python 2, not in Python 3

print ("AFTER")
print (globals().items())  # Python 2 output now includes img_dir='b' and img_type=2
print (locals().items())

print (img_type(20))  # line 19: TypeError on Python 2, prints 20 on Python 3

This outputs:

50
BEFORE
[('img_type', <function img_type at 0x7f740ad5eb18>), ('__builtins__', <module '__builtin__' (built-in)>), ('__file__', './proof1.py'), ('__package__', None), ('__name__', '__main__'), ('__doc__', None)]
[('img_type', <function img_type at 0x7f740ad5eb18>), ('__builtins__', <module '__builtin__' (built-in)>), ('__file__', './proof1.py'), ('__package__', None), ('__name__', '__main__'), ('__doc__', None)]
AFTER
[('img_type', 2), ('a', ['a', 'b']), ('__builtins__', <module '__builtin__' (built-in)>), ('img_dir', 'b'), ('__file__', './proof1.py'), ('__package__', None), ('__name__', '__main__'), ('__doc__', None)]
[('img_type', 2), ('a', ['a', 'b']), ('__builtins__', <module '__builtin__' (built-in)>), ('img_dir', 'b'), ('__file__', './proof1.py'), ('__package__', None), ('__name__', '__main__'), ('__doc__', None)]
Traceback (most recent call last):
  File "./proof1.py", line 19, in <module>
    print (img_type(20))
TypeError: 'int' object is not callable

Notice how after the list comprehension the img_type() function got clobbered by the comprehension's loop variable of the same name.

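This changed in Python 3: since Python 3.0, list comprehensions have their own scope, so their loop variables no longer leak. Here is the output from running the second version of the program under Python 3: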

50
BEFORE
dict_items([('__name__', '__main__'), ('__doc__', None), ('__loader__', <_frozen_importlib_external.SourceFileLoader object at 0x7ff93f4b7780>), ('__file__', './proof1.py'), ('__builtins__', <module 'builtins' (built-in)>), ('__spec__', None), ('img_type', <function img_type at 0x7ff93f41c2f0>), ('__package__', None), ('__cached__', None)])
dict_items([('__name__', '__main__'), ('__doc__', None), ('__loader__', <_frozen_importlib_external.SourceFileLoader object at 0x7ff93f4b7780>), ('__file__', './proof1.py'), ('__builtins__', <module 'builtins' (built-in)>), ('__spec__', None), ('img_type', <function img_type at 0x7ff93f41c2f0>), ('__package__', None), ('__cached__', None)])
AFTER
dict_items([('__name__', '__main__'), ('__doc__', None), ('a', ['a', 'b']), ('__loader__', <_frozen_importlib_external.SourceFileLoader object at 0x7ff93f4b7780>), ('__file__', './proof1.py'), ('__builtins__', <module 'builtins' (built-in)>), ('__spec__', None), ('img_type', <function img_type at 0x7ff93f41c2f0>), ('__package__', None), ('__cached__', None)])
dict_items([('__name__', '__main__'), ('__doc__', None), ('a', ['a', 'b']), ('__loader__', <_frozen_importlib_external.SourceFileLoader object at 0x7ff93f4b7780>), ('__file__', './proof1.py'), ('__builtins__', <module 'builtins' (built-in)>), ('__spec__', None), ('img_type', <function img_type at 0x7ff93f41c2f0>), ('__package__', None), ('__cached__', None)])
20