add tree edit distance implementation and utility functions

This commit is contained in:
Akos Horvath 2024-02-26 16:30:03 +01:00
parent 0fefb31535
commit 11636b2754
2 changed files with 241 additions and 4 deletions

View File

@ -114,6 +114,15 @@ void wm_treenode_swap(TreeNode *root, TreeNode *node1, TreeNode* node2)
node2->client = tmp;
}
void wm_treenode_add_child(TreeNode *node, TreeNode *child)
{
assert(node);
assert(child);
wm_nodearray_push(node->children, child);
child->parent = node;
}
void wm_node_type_to_str(NodeType type, char *buf, size_t bufsize)
{
switch (type) {
@ -326,6 +335,30 @@ NodeArray* wm_treenode_all_lmds(TreeNode *node)
return ret;
}
UIntArray* wm_treenode_all_lmds_index(TreeNode *node)
{
assert(node);
NodeArray *postorder = wm_postorder_traversal(node);
UIntArray *ret = wm_uintarray_new();
for (size_t i = 0; i < postorder->size; i++) {
TreeNode *lmd = wm_treenode_lmd(postorder->nodes[i]);
uintptr_t index = UINT_MAX;
for (index = 0; index < postorder->size; index++) {
if (postorder->nodes[index]->id == lmd->id)
break;
}
assert(index != UINT_MAX);
wm_uintarray_push(ret, index);
}
wm_nodearray_free(postorder);
return ret;
}
NodeArray* wm_postorder_traversal(TreeNode *tree)
{
assert(tree);
@ -385,12 +418,30 @@ NodeArray* wm_postorder_traversal(TreeNode *tree)
return ret;
}
bool wm_is_treenode_keyroot(TreeNode *node)
bool wm_is_treenode_keyroot(TreeNode *node, NodeArray *postorder)
{
assert(node);
assert(postorder);
if (!node->parent) return true;
return wm_treenode_lmd(node) != wm_treenode_lmd(node->parent);
TreeNode *node_lmd = wm_treenode_lmd(node);
size_t node_index = SIZE_MAX;
for (size_t i = 0; i < postorder->size; i++) {
if (node->id == postorder->nodes[i]->id) {
node_index = i;
break;
}
}
assert(node_index != SIZE_MAX);
for (size_t i = node_index + 1; i < postorder->size; i++) {
if (node_lmd->id == wm_treenode_lmd(postorder->nodes[i])->id)
return false;
}
return true;
}
NodeArray* wm_treenode_all_keyroots(TreeNode *node)
@ -402,7 +453,7 @@ NodeArray* wm_treenode_all_keyroots(TreeNode *node)
for (size_t i = 0; i < postorder->size; i++) {
TreeNode *node = postorder->nodes[i];
if (wm_is_treenode_keyroot(node))
if (wm_is_treenode_keyroot(node, postorder))
wm_nodearray_push(ret, node);
}
@ -410,6 +461,171 @@ NodeArray* wm_treenode_all_keyroots(TreeNode *node)
return ret;
}
UIntArray* wm_treenode_all_keyroots_index(TreeNode *node)
{
assert(node);
NodeArray *postorder = wm_postorder_traversal(node);
UIntArray *ret = wm_uintarray_new();
for (size_t i = 0; i < postorder->size; i++) {
if (wm_is_treenode_keyroot(postorder->nodes[i], postorder))
wm_uintarray_push(ret, i);
}
wm_nodearray_free(postorder);
return ret;
}
bool wm_nodes_are_equal(TreeNode *node1, TreeNode *node2)
{
if (node1->type != node2->type)
return false;
return true;
}
// https://github.com/timtadh/zhang-shasha
void wm_treedist(TreeNode *tree1, TreeNode *tree2, size_t i, size_t j,
UIntArray *dists, TreeEditDistanceCosts costs)
{
assert(tree1);
assert(tree2);
NodeArray *tree1_postorder = wm_postorder_traversal(tree1);
NodeArray *tree2_postorder = wm_postorder_traversal(tree2);
UIntArray *tree1_lmds = wm_treenode_all_lmds_index(tree1);
UIntArray *tree2_lmds = wm_treenode_all_lmds_index(tree2);
size_t m = i - wm_uintarray_at(tree1_lmds, i) + 2;
size_t n = j - wm_uintarray_at(tree2_lmds, j) + 2;
UIntArray *fd = wm_uintarray_new();
_Static_assert(sizeof(uint64_t) == sizeof(uint64_t*), "static assert failed");
for (size_t k = 0; k < m; k++) {
wm_uintarray_push(fd, (uint64_t)calloc(n, sizeof(uint64_t)));
}
size_t ioff = wm_uintarray_at(tree1_lmds, i) - 1;
size_t joff = wm_uintarray_at(tree2_lmds, j) - 1;
for (size_t x = 1; x < m; x++) {
TreeNode *node = tree1_postorder->nodes[x + ioff];
((uint64_t**)fd->elements)[x][0] = ((size_t)wm_uintarray_2d_index(fd,
x - 1, 0) + (size_t)costs.remove_cost);
}
for (size_t y = 1; y < n; y++) {
TreeNode *node = tree2_postorder->nodes[y + joff];
((uint64_t**)fd->elements)[0][y] = ((size_t)wm_uintarray_2d_index(fd,
0, y - 1) + (size_t)costs.insert_cost);
}
for (size_t x = 1; x < m; x++) {
for (size_t y = 1; y < n; y++) {
TreeNode *node1 = tree1_postorder->nodes[x + ioff];
TreeNode *node2 = tree2_postorder->nodes[y + joff];
if (wm_uintarray_at(tree1_lmds, i) == tree1_lmds->elements[x + ioff] &&
wm_uintarray_at(tree2_lmds, j) == tree2_lmds->elements[y + joff]) {
size_t _costs[3] = {
wm_uintarray_2d_index(fd, x - 1, y) + costs.remove_cost, //(node1)
wm_uintarray_2d_index(fd, x, y - 1) + costs.insert_cost, //(node2)
wm_uintarray_2d_index(fd, x - 1, y - 1) + (*costs.update_cost_function)(node1, node2),
};
size_t min = SIZE_MAX;
for (size_t k = 0; k < 3; k++) {
if (_costs[k] < min) {
min = _costs[k];
}
}
assert(min != SIZE_MAX);
((uint64_t**)fd->elements)[x][y] = min;
((uint64_t**)dists->elements)[x + ioff][y + joff] = min;
} else {
size_t p = tree1_lmds->elements[x + ioff] - 1 - ioff;
size_t q = tree2_lmds->elements[y + joff] - 1 - joff;
size_t _costs[3] = {
(size_t)wm_uintarray_2d_index(fd, x - 1, y) + costs.remove_cost, // (node1)
(size_t)wm_uintarray_2d_index(fd, x, y - 1) + costs.insert_cost, // (node2)
(size_t)wm_uintarray_2d_index(fd, p, q) +
(size_t)wm_uintarray_2d_index(dists, x + ioff, y + joff),
};
size_t min = SIZE_MAX;
for (size_t k = 0; k < 3; k++) {
if (_costs[k] < min) {
min = _costs[k];
}
}
assert(min != SIZE_MAX);
((uint64_t**)fd->elements)[x][y] = min;
}
}
}
for (size_t k = 0; k < m; k++) {
free((void*)wm_uintarray_at(fd, k));
}
wm_nodearray_free(tree1_postorder);
wm_nodearray_free(tree2_postorder);
wm_uintarray_free(tree1_lmds);
wm_uintarray_free(tree2_lmds);
wm_uintarray_free(fd);
}
// https://github.com/timtadh/zhang-shasha
size_t wm_tree_edit_distance(TreeNode *tree1, TreeNode *tree2, TreeEditDistanceCosts costs)
{
assert(tree1);
assert(tree2);
NodeArray *tree1_postorder = wm_postorder_traversal(tree1);
NodeArray *tree2_postorder = wm_postorder_traversal(tree2);
UIntArray *tree1_keyroot_indexes = wm_treenode_all_keyroots_index(tree1);
UIntArray *tree2_keyroot_indexes = wm_treenode_all_keyroots_index(tree2);
UIntArray *dists = wm_uintarray_new();
_Static_assert(sizeof(uint64_t) == sizeof(uint64_t*), "static assert failed");
for (size_t k = 0; k < tree1_postorder->size; k++) {
wm_uintarray_push(dists, (uint64_t)calloc(tree2_postorder->size, sizeof(uint64_t)));
}
for (size_t p = 0; p < tree1_keyroot_indexes->size; p++) {
for (size_t q = 0; q < tree2_keyroot_indexes->size; q++) {
size_t i = wm_uintarray_at(tree1_keyroot_indexes, p);
size_t j = wm_uintarray_at(tree2_keyroot_indexes, q);
wm_treedist(tree1, tree2, i, j, dists, costs);
}
}
size_t ret = wm_uintarray_2d_index(dists, tree1_postorder->size - 1, tree2_postorder->size - 1);
for (size_t k = 0; k < tree1_postorder->size; k++) {
free((void*)wm_uintarray_at(dists, k));
}
wm_nodearray_free(tree1_postorder);
wm_nodearray_free(tree2_postorder);
wm_uintarray_free(tree1_keyroot_indexes);
wm_uintarray_free(tree2_keyroot_indexes);
wm_uintarray_free(dists);
return ret;
}
void wm_treenode_print(TreeNode *node)
{
char *str = wm_treenode_to_str(node);

View File

@ -61,13 +61,23 @@ typedef struct {
char *str;
} WmWorkspaceToStrRet;
typedef uint64_t (*UpdateCostFunction)(TreeNode*, TreeNode*);
typedef struct {
uint64_t insert_cost;
uint64_t remove_cost;
UpdateCostFunction update_cost_function;
} TreeEditDistanceCosts;
NodeArray* wm_nodearray_new();
void wm_nodearray_set(NodeArray *arr, size_t index, uint64_t value);
void wm_nodearray_push(NodeArray *arr, TreeNode *node);
bool wm_nodearray_pop(NodeArray *arr, TreeNode **ret);
bool wm_nodearray_pop_front(NodeArray *arr, TreeNode **ret);
void wm_nodearray_clear(NodeArray *arr);
bool wm_nodearray_remove(NodeArray *arr, size_t index);
void wm_nodearray_free(NodeArray *arr);
TreeNode* wm_nodearray_at(NodeArray *arr, size_t index);
TreeNode* wm_treenode_new(NodeType type, TreeNode *parent);
void wm_treenode_free(TreeNode *node);
@ -75,9 +85,11 @@ bool wm_treenode_is_empty(TreeNode *node);
void wm_treenode_split_space(TreeNode *node, Rect *ret1, Rect *ret2);
void wm_treenode_recalculate_space(TreeNode *node);
void wm_treenode_swap(TreeNode *root, TreeNode *node1, TreeNode* node2);
void wm_treenode_add_child(TreeNode *node, TreeNode *child);
TreeNode* wm_treenode_remove_client(Wm *wm, TreeNode *root, Client *client);
void wm_treenode_remove_node(Wm *wm, TreeNode *root, TreeNode *node);
int wm_get_node_index(TreeNode *parent, unsigned int node_id);
void wm_tree_to_DOT(TreeNode *root, const char *filename);
void wm_node_type_to_str(NodeType type, char *buf, size_t bufsize);
UIntArray* wm_nonempty_workspaces_to_strptrarray(Wm *wm);
@ -104,8 +116,17 @@ void wm_log_state(Wm *wm, const char *prefixstr, const char* logfile);
TreeNode* wm_treenode_lmd(TreeNode *node);
NodeArray* wm_treenode_all_lmds(TreeNode *node);
UIntArray* wm_treenode_all_lmds_index(TreeNode *node);
NodeArray* wm_postorder_traversal(TreeNode *tree);
bool wm_is_treenode_keyroot(TreeNode *node);
bool wm_is_treenode_keyroot(TreeNode *node, NodeArray *postorder);
NodeArray* wm_treenode_all_keyroots(TreeNode *node);
UIntArray* wm_treenode_all_keyroots_index(TreeNode *node);
bool wm_nodes_are_equal(TreeNode *node1, TreeNode *node2);
void wm_treedist(TreeNode *tree1, TreeNode *tree2, size_t i, size_t j,
UIntArray *dists, TreeEditDistanceCosts costs);
size_t wm_tree_edit_distance(TreeNode *tree1, TreeNode *tree2, TreeEditDistanceCosts costs);
#endif