Merge Sort Of Lists

| Category: Kernel | Tags: kernel, sort, list

list_sort.c in the Linux kernel provides a function to sort a doubly linked list. Its header comment says:

    /**
     * This function implements "merge sort", which has O(nlog(n))
     * complexity.
     */

But if you look at the implementation, you may be surprised: it differs from an ordinary "merge sort"!

    /*
     * Returns a list organized in an intermediate format suited
     * to chaining of merge() calls: null-terminated, no reserved or
     * sentinel head node, "prev" links not maintained.
     */
    static struct list_head *merge(void *priv,
                                   int (*cmp)(void *priv, struct list_head *a,
                                              struct list_head *b),
                                   struct list_head *a, struct list_head *b)
    {
        struct list_head head, *tail = &head;

        while (a && b) {
            /* if equal, take 'a' -- important for sort stability */
            if ((*cmp)(priv, a, b) <= 0) {
                tail->next = a;
                a = a->next;
            } else {
                tail->next = b;
                b = b->next;
            }
            tail = tail->next;
        }
        tail->next = a?:b;
        return head.next;
    }

The merge function is clear and clean compared with ordinary ones, for example this one from the wiki:

    typedef struct _slist
    {
        struct _slist* next;
        int   val;
    } slist;


    slist* merge(slist*a, slist* b)
    {
        sts ++;
        slist* head = NULL;
        slist* tail = NULL;
        slist* p = a;
        slist* q = b;
        while (p && q) {
            if (p->val < q->val) {
                if (!head) {
                    head = p;
                    tail = p;
                    p = p->next;
                }
                else {
                    tail->next = p;
                    tail = p;
                    p = p->next;
                }
            }
            else {
                if (!head) {
                    head = q;
                    tail = q;
                    q = q->next;
                }
                else {
                    tail->next = q;
                    tail = q;
                    q = q->next;
                }
            }
        }

        if (p)
            tail->next = p;
        if (q)
            tail->next = q;

        return head;
    }
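Most of the difference between the two comes from one trick: the kernel version starts with a dummy head node on the stack, so the "is this the first node?" branches disappear. Here is a minimal sketch of that trick applied to a plain singly linked list. The names slist and merge_dummy are my own for illustration and are not from the kernel or the wiki.

```c
#include <assert.h>
#include <stddef.h>

typedef struct slist {
    struct slist *next;
    int val;
} slist;

/* Merge two sorted, NULL-terminated lists using a local dummy head. */
static slist *merge_dummy(slist *a, slist *b)
{
    slist head = { NULL, 0 };   /* dummy node: lives on the stack, never returned */
    slist *tail = &head;

    /* Merging a non-empty list with itself would create a cycle. */
    assert(!a || a != b);

    while (a && b) {
        if (a->val <= b->val) { /* "<=" takes 'a' on ties, which keeps the sort stable */
            tail->next = a;
            a = a->next;
        } else {
            tail->next = b;
            b = b->next;
        }
        tail = tail->next;
    }
    tail->next = a ? a : b;     /* append whichever list still has nodes */
    return head.next;           /* also correct when a or b is NULL */
}
```

Because tail always points at a real node (the dummy at first), the loop body needs no special case, and an empty input simply falls through to the last two lines.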

And the sort function:

    /**
     * list_sort - sort a list
     * @priv: private data, opaque to list_sort(), passed to @cmp
     * @head: the list to sort
     * @cmp: the elements comparison function
     *
     * This function implements "merge sort", which has O(nlog(n))
     * complexity.
     *
     * The comparison function @cmp must return a negative value if @a
     * should sort before @b, and a positive value if @a should sort after
     * @b. If @a and @b are equivalent, and their original relative
     * ordering is to be preserved, @cmp must return 0.
     */
    void list_sort(void *priv, struct list_head *head,
                   int (*cmp)(void *priv, struct list_head *a,
                              struct list_head *b))
    {
        struct list_head *part[MAX_LIST_LENGTH_BITS+1]; /* sorted partial lists
                                                           -- last slot is a sentinel */
        int lev;  /* index into part[] */
        int max_lev = 0;
        struct list_head *list;

        if (list_empty(head))
            return;

        memset(part, 0, sizeof(part));

        head->prev->next = NULL;
        list = head->next;

        while (list) {
            struct list_head *cur = list;
            list = list->next;
            cur->next = NULL;

            for (lev = 0; part[lev]; lev++) {
                cur = merge(priv, cmp, part[lev], cur);
                part[lev] = NULL;
            }
            if (lev > max_lev) {
                if (unlikely(lev >= ARRAY_SIZE(part)-1)) {
                    printk_once(KERN_DEBUG "list passed to"
                                " list_sort() too long for"
                                " efficiency\n");
                    lev--;
                }
                max_lev = lev;
            }
            part[lev] = cur;
        }

        for (lev = 0; lev < max_lev; lev++)
            if (part[lev])
                list = merge(priv, cmp, part[lev], list);

        merge_and_restore_back_links(priv, cmp, head, part[max_lev], list);
    }
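The comment on @cmp is worth a small illustration. Below is a sketch of what a comparison callback with that signature could look like. It is written for userspace, so struct list_head and container_of are simplified stand-ins that I define myself, and struct item, item_cmp and the use of priv as a call counter are hypothetical choices for this example only.

```c
#include <assert.h>
#include <stddef.h>

/* Userspace stand-ins for the kernel definitions (assumptions for this sketch). */
struct list_head {
    struct list_head *next, *prev;
};

#define container_of(ptr, type, member) \
    ((type *)((char *)(ptr) - offsetof(type, member)))

/* Hypothetical payload type that embeds a list node. */
struct item {
    int key;
    struct list_head node;
};

/*
 * Comparison callback in the shape list_sort() expects:
 * negative if a sorts before b, positive if after, 0 if equivalent.
 * Here priv, when non-NULL, points to a counter of comparisons.
 */
static int item_cmp(void *priv, struct list_head *a, struct list_head *b)
{
    struct item *ia, *ib;

    assert(a && b);
    ia = container_of(a, struct item, node);
    ib = container_of(b, struct item, node);

    if (priv)
        ++*(unsigned long *)priv;

    /* Explicit comparisons avoid the overflow that "ia->key - ib->key" can hit. */
    if (ia->key < ib->key)
        return -1;
    if (ia->key > ib->key)
        return 1;
    return 0;   /* equal keys: 0 lets merge() keep the original order */
}
```

Returning 0 for equal keys matters because merge() takes 'a' whenever the result is <= 0, and 'a' is always the run that came earlier in the input.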

But wait, is this really merge sort? A merge sort normally consists of two parts: dividing the list into halves that are sorted recursively, and merging the sorted halves. For example, this one:

    /* Global function-call counter used for the measurements below;
     * its declaration is not in the listing, so the type is assumed. */
    static unsigned long sts;

    /* Returns the last node of the first half (found with a fast and a slow pointer). */
    slist* get_middle_list(slist* h)
    {
        sts ++;
        slist* p = h;
        slist* q = h;
        slist* k = q;
        while (p) {
            p = p->next;
            if (p) {
                p = p->next;
            }
            k = q;
            q = q->next;
        }
        return k;
    }

    /* Both a and b must be non-empty: tail is still NULL after the loop otherwise. */
    slist* merge(slist*a, slist* b)
    {
        sts ++;
        slist* head = NULL;
        slist* tail = NULL;
        slist* p = a;
        slist* q = b;
        while (p && q) {
            if (p->val < q->val) {
                if (!head) {
                    head = p;
                    tail = p;
                    p = p->next;
                }
                else {
                    tail->next = p;
                    tail = p;
                    p = p->next;
                }
            }
            else {
                if (!head) {
                    head = q;
                    tail = q;
                    q = q->next;
                }
                else {
                    tail->next = q;
                    tail = q;
                    q = q->next;
                }
            }
        }

        if (p)
            tail->next = p;
        if (q)
            tail->next = q;

        return head;
    }


    slist* merge_sort_1(slist* h)
    {
        sts ++;
        /* An empty or single-node list is already sorted. */
        if (!h || !h->next)
            return h;

        slist* s = get_middle_list(h);
        slist* nh = s->next;
        s->next = NULL;     /* cut the list into two halves */
        return merge(merge_sort_1(h), merge_sort_1(nh));
    }

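The kernel implementation does not use a helper function to "divide" the list. Instead it takes nodes one at a time and keeps a table, part[], of partially sorted lists: part[k] is either empty or, as long as the table has not overflowed, holds a sorted run of 2^k nodes. Adding a node works like incrementing a binary counter: the node is merged with part[0], part[1], and so on until an empty slot is found, and every slot passed on the way is emptied. The following figure shows what part[] looks like after the first 127 (2^7 - 1) elements have been processed: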

list_sort.png

After the 128th element has been processed, part[0] through part[6] are empty, and part[7] holds a sorted list of the first 128 nodes:

lsort_2.png
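To convince myself of the binary-counter behaviour, it helps to strip the idea down to a singly linked list. The following is only a sketch under my own assumptions, not the kernel code: slist, merge_runs, part_add, occupied_mask, bottom_up_sort and MAX_BITS are all names made up for this example, and the overflow handling is simpler than the kernel's sentinel slot.

```c
#include <assert.h>
#include <stddef.h>

#define MAX_BITS 20          /* hypothetical table size for this sketch */

typedef struct slist {
    struct slist *next;
    int val;
} slist;

/* Stable merge of two sorted, NULL-terminated lists (dummy-head style). */
static slist *merge_runs(slist *a, slist *b)
{
    slist head = { NULL, 0 }, *tail = &head;

    while (a && b) {
        if (a->val <= b->val) {
            tail->next = a;
            a = a->next;
        } else {
            tail->next = b;
            b = b->next;
        }
        tail = tail->next;
    }
    tail->next = a ? a : b;
    return head.next;
}

/* Add one node to the table; works like adding 1 to a binary counter. */
static void part_add(slist *part[], slist *cur)
{
    int lev;

    assert(cur != NULL);
    cur->next = NULL;
    for (lev = 0; lev < MAX_BITS && part[lev]; lev++) {
        cur = merge_runs(part[lev], cur);   /* carry into the next slot */
        part[lev] = NULL;
    }
    /* The top slot absorbs any overflow instead of running off the array. */
    part[lev] = part[lev] ? merge_runs(part[lev], cur) : cur;
}

/* Bit k of the result is set when part[k] is non-empty. */
static unsigned long occupied_mask(slist *part[])
{
    unsigned long mask = 0;
    int lev;

    for (lev = 0; lev <= MAX_BITS; lev++)
        if (part[lev])
            mask |= 1UL << lev;
    return mask;
}

/* Sort a NULL-terminated list without any recursive splitting. */
static slist *bottom_up_sort(slist *list)
{
    slist *part[MAX_BITS + 1] = { NULL };
    slist *result = NULL;
    int lev;

    while (list) {
        slist *cur = list;
        list = list->next;
        part_add(part, cur);
    }
    /* Fold the leftover runs together, smallest slot first. */
    for (lev = 0; lev <= MAX_BITS; lev++)
        if (part[lev])
            result = merge_runs(part[lev], result);
    return result;
}
```

With this sketch, calling part_add() 127 times leaves occupied_mask() at 0x7f (slots 0 to 6 in use), and the 128th call collapses everything into slot 7, giving 0x80. The older run is always passed as the first argument to merge_runs(), which is what keeps equal elements in their original order.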

I don't know how the kernel developers came up with this subtle solution. It is more efficient and more elegant than the ordinary versions.

To test it, I wrote a similar implementation using the same algorithm. Sorting 1M elements takes 411 ms with 570641 function calls, while the C implementation from the wiki takes more than 502 ms with 7452856 function calls.

Comparing the wiki solution with the kernel one, I guess that both the extra function calls and the extra 'if-else' statements hurt performance. Normally, though, that cost should be acceptable.

I put my test code here; you can download it and try it yourself.

