added pruning to trie

This commit is contained in:
Leyla Becker 2025-07-19 20:14:21 -05:00
parent f93d24dc39
commit a620903c10
2 changed files with 48 additions and 5 deletions

View file

@ -2,7 +2,7 @@
// Import the module and reference it with the alias vscode in your code below
import * as vscode from 'vscode';
import { Ollama } from 'ollama/browser';
import { flattenTrie, trieInsert, trieLookup, TrieNode } from './trie';
import { flattenTrie, trieInsert, trieLookup, TrieNode, triePrune } from './trie';
const MODEL = 'deepseek-coder:6.7b';
@ -14,6 +14,7 @@ const SUFFIX_END = '</suffixStart>';
const MAX_TOKENS = 50;
const GENERATION_TIMEOUT = 200;
const TRIE_PRUNE_TIMEOUT = 10000;
const HOST = undefined;
@ -70,9 +71,8 @@ const getSuffix = (document: vscode.TextDocument, position: vscode.Position) =>
};
let trieRoot: TrieNode = {
isLeaf: false,
isLeaf: true,
value: '',
children: {},
};
const trieRootInsert = (text: string) => {
@ -83,6 +83,10 @@ const trieRootLookup = (text: string) => {
return trieLookup(trieRoot, text);
};
const trieRootPrune = (text: string) => {
return triePrune(trieRoot, text);
};
const tokenProvider = async (
document: vscode.TextDocument,
position: vscode.Position,
@ -102,6 +106,10 @@ const tokenProvider = async (
return flattenTrie(result);
}
if (token.isCancellationRequested) {
return [];
}
const response = await ollama.generate({
model: MODEL,
prompt,
@ -114,8 +122,15 @@ const tokenProvider = async (
},
});
const pruneTimeout = setTimeout(() => {
trieRootPrune(prompt);
}, TRIE_PRUNE_TIMEOUT);
token.onCancellationRequested(() => {
clearTimeout(pruneTimeout);
try {
response.abort();
} catch {}
});
const resultBuffer: string[] = await new Promise(async (resolve, reject) => {

View file

@ -96,6 +96,12 @@ export const trieInsert = (node: TrieNode, text: string): TrieNode => {
return node;
};
/**
* Gets a new trie that is the node trie after the text
* @param node
* @param text
* @returns
*/
export const trieLookup = (node: TrieNode, text: string): TrieNode | null => {
for (let index = 0; index < node.value.length; index++) {
// If our node still has text left but we have no more search query return a node starting at where we ran out of characters
@ -133,7 +139,6 @@ export const trieLookup = (node: TrieNode, text: string): TrieNode | null => {
return child === undefined ? null : trieLookup(child, childText);
};
export const flattenTrie = (node: TrieNode): string[] => {
if (node.isLeaf) {
return [
@ -145,3 +150,26 @@ export const flattenTrie = (node: TrieNode): string[] => {
return flattenTrie(child).map((value) => node.value + key + value);
});
};
export const triePrune = (node: TrieNode, text: string): TrieNode => {
const value = trieLookup(node, text);
if (value === null) {
return {
isLeaf: true,
value: text,
};
}
if (value.isLeaf) {
return {
isLeaf: true,
value: text + value.value,
};
}
return {
isLeaf: false,
value: text + value.value,
children: value.children,
};
};