used trie to store auto complete results

This commit is contained in:
Leyla Becker 2025-07-19 19:38:41 -05:00
parent ffd384fbe4
commit 55c72ff98c
2 changed files with 54 additions and 38 deletions

View file

@ -2,15 +2,12 @@
// Import the module and reference it with the alias vscode in your code below // Import the module and reference it with the alias vscode in your code below
import * as vscode from 'vscode'; import * as vscode from 'vscode';
import { Ollama } from 'ollama/browser'; import { Ollama } from 'ollama/browser';
import { trieInsert, trieLookup, TrieNode } from './trie'; import { flattenTrie, trieInsert, trieLookup, TrieNode } from './trie';
const MODEL = 'deepseek-coder:6.7b'; const MODEL = 'deepseek-coder:6.7b';
const PREFIX_START = '<prefixStart>'; const PREFIX_START = '<prefixStart>';
const PREFIX_END = '<prefixEnd>'; const PREFIX_ENDS = ['<prefixEnd>', '</prefixEnd>'];
const SUFFIX_START = '<suffixStart>';
const SUFFIX_END = '<suffixEnd>';
const MAX_TOKENS = 50; const MAX_TOKENS = 50;
const GENERATION_TIMEOUT = 200; const GENERATION_TIMEOUT = 200;
@ -74,12 +71,20 @@ const tokenProvider = async (
_token: vscode.CancellationToken, _token: vscode.CancellationToken,
) => { ) => {
const prefix = document.getText(new vscode.Range(0, 0, position.line, position.character)); const prefix = document.getText(new vscode.Range(0, 0, position.line, position.character));
const modelSupportsSuffix = await getModelSupportsSuffix(MODEL); const modelSupportsSuffix = await getModelSupportsSuffix(MODEL);
const prompt = getPrompt(document, position, prefix); const prompt = getPrompt(document, position, prefix);
const suffix = modelSupportsSuffix ? getSuffix(document, position) : undefined; const suffix = modelSupportsSuffix ? getSuffix(document, position) : undefined;
console.log(JSON.stringify(trieRoot));
const result = trieRootLookup(prefix);
if (result !== null) {
return flattenTrie(result);
}
const response = await ollama.generate({ const response = await ollama.generate({
model: MODEL, model: MODEL,
prompt, prompt,
@ -88,11 +93,33 @@ const tokenProvider = async (
stream: true, stream: true,
options: { options: {
num_predict: MAX_TOKENS, num_predict: MAX_TOKENS,
stop: [PREFIX_END] stop: PREFIX_ENDS,
}, },
}); });
return response; const resultBuffer: string[] = await new Promise(async (resolve, reject) => {
const buffer: string[] = [];
const timeout = setTimeout(() => {
resolve(buffer);
}, GENERATION_TIMEOUT);
try {
for await (const part of response) {
buffer.push(part.response);
trieRootInsert(prefix + buffer.join(''));
}
resolve(buffer);
} catch (err) {
reject(err);
} finally {
response.abort();
clearTimeout(timeout);
};
});
return [
resultBuffer.join('')
];
}; };
export const activate = (context: vscode.ExtensionContext) => { export const activate = (context: vscode.ExtensionContext) => {
@ -101,36 +128,12 @@ export const activate = (context: vscode.ExtensionContext) => {
const provider: vscode.InlineCompletionItemProvider = { const provider: vscode.InlineCompletionItemProvider = {
async provideInlineCompletionItems(document, position, context, token) { async provideInlineCompletionItems(document, position, context, token) {
try { try {
const response = await tokenProvider(document, position, context, token); const completions = await tokenProvider(document, position, context, token);
const resultBuffer: string[] = await new Promise(async (resolve, reject) => { return completions.map((text) => ({
const buffer: string[] = []; insertText: text,
const timeout = setTimeout(() => { range: new vscode.Range(position, position),
resolve(buffer); }));
}, GENERATION_TIMEOUT);
try {
for await (const part of response) {
// process.stdout.write(part.response);
buffer.push(part.response);
}
resolve(buffer);
} catch (err) {
reject(err);
} finally {
response.abort();
clearTimeout(timeout);
};
});
const text = resultBuffer.join('');
return [
{
insertText: text,
range: new vscode.Range(position, position),
}
];
} catch (err) { } catch (err) {
console.log(err); console.log(err);
} }
@ -143,4 +146,4 @@ export const activate = (context: vscode.ExtensionContext) => {
}; };
// This method is called when your extension is deactivated // This method is called when your extension is deactivated
export function deactivate() {} export function deactivate() { }

View file

@ -132,3 +132,16 @@ export const trieLookup = (node: TrieNode, text: string): TrieNode | null => {
const child = node.children[childKey]; const child = node.children[childKey];
return child === undefined ? null : trieLookup(child, childText); return child === undefined ? null : trieLookup(child, childText);
}; };
export const flattenTrie = (node: TrieNode): string[] => {
if (node.isLeaf) {
return [
node.value
];
}
return Object.entries(node.children).flatMap(([key, child]) => {
return flattenTrie(child).map((value) => node.value + key + value);
});
};