From 55c72ff98ca1cb348dd9131fd577c01251a225a1 Mon Sep 17 00:00:00 2001 From: Leyla Becker Date: Sat, 19 Jul 2025 19:38:41 -0500 Subject: [PATCH] used trie to store auto complete results --- src/extension.ts | 79 +++++++++++++++++++++++++----------------------- src/trie.ts | 13 ++++++++ 2 files changed, 54 insertions(+), 38 deletions(-) diff --git a/src/extension.ts b/src/extension.ts index f66ff54..02cfdb1 100644 --- a/src/extension.ts +++ b/src/extension.ts @@ -2,15 +2,12 @@ // Import the module and reference it with the alias vscode in your code below import * as vscode from 'vscode'; import { Ollama } from 'ollama/browser'; -import { trieInsert, trieLookup, TrieNode } from './trie'; +import { flattenTrie, trieInsert, trieLookup, TrieNode } from './trie'; const MODEL = 'deepseek-coder:6.7b'; const PREFIX_START = ''; -const PREFIX_END = ''; - -const SUFFIX_START = ''; -const SUFFIX_END = ''; +const PREFIX_ENDS = ['', '']; const MAX_TOKENS = 50; const GENERATION_TIMEOUT = 200; @@ -74,12 +71,20 @@ const tokenProvider = async ( _token: vscode.CancellationToken, ) => { const prefix = document.getText(new vscode.Range(0, 0, position.line, position.character)); - + const modelSupportsSuffix = await getModelSupportsSuffix(MODEL); const prompt = getPrompt(document, position, prefix); const suffix = modelSupportsSuffix ? getSuffix(document, position) : undefined; + console.log(JSON.stringify(trieRoot)); + + const result = trieRootLookup(prefix); + + if (result !== null) { + return flattenTrie(result); + } + const response = await ollama.generate({ model: MODEL, prompt, @@ -88,11 +93,33 @@ const tokenProvider = async ( stream: true, options: { num_predict: MAX_TOKENS, - stop: [PREFIX_END] + stop: PREFIX_ENDS, }, }); - return response; + const resultBuffer: string[] = await new Promise(async (resolve, reject) => { + const buffer: string[] = []; + const timeout = setTimeout(() => { + resolve(buffer); + }, GENERATION_TIMEOUT); + + try { + for await (const part of response) { + buffer.push(part.response); + trieRootInsert(prefix + buffer.join('')); + } + resolve(buffer); + } catch (err) { + reject(err); + } finally { + response.abort(); + clearTimeout(timeout); + }; + }); + + return [ + resultBuffer.join('') + ]; }; export const activate = (context: vscode.ExtensionContext) => { @@ -101,36 +128,12 @@ export const activate = (context: vscode.ExtensionContext) => { const provider: vscode.InlineCompletionItemProvider = { async provideInlineCompletionItems(document, position, context, token) { try { - const response = await tokenProvider(document, position, context, token); + const completions = await tokenProvider(document, position, context, token); - const resultBuffer: string[] = await new Promise(async (resolve, reject) => { - const buffer: string[] = []; - const timeout = setTimeout(() => { - resolve(buffer); - }, GENERATION_TIMEOUT); - - try { - for await (const part of response) { - // process.stdout.write(part.response); - buffer.push(part.response); - } - resolve(buffer); - } catch (err) { - reject(err); - } finally { - response.abort(); - clearTimeout(timeout); - }; - }); - - const text = resultBuffer.join(''); - - return [ - { - insertText: text, - range: new vscode.Range(position, position), - } - ]; + return completions.map((text) => ({ + insertText: text, + range: new vscode.Range(position, position), + })); } catch (err) { console.log(err); } @@ -143,4 +146,4 @@ export const activate = (context: vscode.ExtensionContext) => { }; // This method is called when your extension is deactivated -export function deactivate() {} +export function deactivate() { } diff --git a/src/trie.ts b/src/trie.ts index 3834544..678df44 100644 --- a/src/trie.ts +++ b/src/trie.ts @@ -132,3 +132,16 @@ export const trieLookup = (node: TrieNode, text: string): TrieNode | null => { const child = node.children[childKey]; return child === undefined ? null : trieLookup(child, childText); }; + + +export const flattenTrie = (node: TrieNode): string[] => { + if (node.isLeaf) { + return [ + node.value + ]; + } + + return Object.entries(node.children).flatMap(([key, child]) => { + return flattenTrie(child).map((value) => node.value + key + value); + }); +};