diff --git a/src/extension.ts b/src/extension.ts index 31e681e..3f3133f 100644 --- a/src/extension.ts +++ b/src/extension.ts @@ -2,7 +2,7 @@ // Import the module and reference it with the alias vscode in your code below import * as vscode from 'vscode'; import { Ollama } from 'ollama/browser'; -import { flattenTrie, trieInsert, trieLookup, TrieNode } from './trie'; +import { flattenTrie, trieInsert, trieLookup, TrieNode, triePrune } from './trie'; const MODEL = 'deepseek-coder:6.7b'; @@ -14,6 +14,7 @@ const SUFFIX_END = ''; const MAX_TOKENS = 50; const GENERATION_TIMEOUT = 200; +const TRIE_PRUNE_TIMEOUT = 10000; const HOST = undefined; @@ -70,9 +71,8 @@ const getSuffix = (document: vscode.TextDocument, position: vscode.Position) => }; let trieRoot: TrieNode = { - isLeaf: false, + isLeaf: true, value: '', - children: {}, }; const trieRootInsert = (text: string) => { @@ -83,6 +83,10 @@ const trieRootLookup = (text: string) => { return trieLookup(trieRoot, text); }; +const trieRootPrune = (text: string) => { + return triePrune(trieRoot, text); +}; + const tokenProvider = async ( document: vscode.TextDocument, position: vscode.Position, @@ -102,6 +106,10 @@ const tokenProvider = async ( return flattenTrie(result); } + if (token.isCancellationRequested) { + return []; + } + const response = await ollama.generate({ model: MODEL, prompt, @@ -114,8 +122,15 @@ const tokenProvider = async ( }, }); + const pruneTimeout = setTimeout(() => { + trieRootPrune(prompt); + }, TRIE_PRUNE_TIMEOUT); + token.onCancellationRequested(() => { - response.abort(); + clearTimeout(pruneTimeout); + try { + response.abort(); + } catch {} }); const resultBuffer: string[] = await new Promise(async (resolve, reject) => { diff --git a/src/trie.ts b/src/trie.ts index 678df44..3ed2ed9 100644 --- a/src/trie.ts +++ b/src/trie.ts @@ -96,6 +96,12 @@ export const trieInsert = (node: TrieNode, text: string): TrieNode => { return node; }; +/** + * Gets a new trie that is the node trie after the text + * @param node + * @param text + * @returns + */ export const trieLookup = (node: TrieNode, text: string): TrieNode | null => { for (let index = 0; index < node.value.length; index++) { // If our node still has text left but we have no more search query return a node starting at where we ran out of characters @@ -133,7 +139,6 @@ export const trieLookup = (node: TrieNode, text: string): TrieNode | null => { return child === undefined ? null : trieLookup(child, childText); }; - export const flattenTrie = (node: TrieNode): string[] => { if (node.isLeaf) { return [ @@ -145,3 +150,26 @@ export const flattenTrie = (node: TrieNode): string[] => { return flattenTrie(child).map((value) => node.value + key + value); }); }; + +export const triePrune = (node: TrieNode, text: string): TrieNode => { + const value = trieLookup(node, text); + if (value === null) { + return { + isLeaf: true, + value: text, + }; + } + + if (value.isLeaf) { + return { + isLeaf: true, + value: text + value.value, + }; + } + + return { + isLeaf: false, + value: text + value.value, + children: value.children, + }; +};