used trie to store auto complete results
This commit is contained in:
parent
ffd384fbe4
commit
55c72ff98c
2 changed files with 54 additions and 38 deletions
|
@ -2,15 +2,12 @@
|
|||
// Import the module and reference it with the alias vscode in your code below
|
||||
import * as vscode from 'vscode';
|
||||
import { Ollama } from 'ollama/browser';
|
||||
import { trieInsert, trieLookup, TrieNode } from './trie';
|
||||
import { flattenTrie, trieInsert, trieLookup, TrieNode } from './trie';
|
||||
|
||||
const MODEL = 'deepseek-coder:6.7b';
|
||||
|
||||
const PREFIX_START = '<prefixStart>';
|
||||
const PREFIX_END = '<prefixEnd>';
|
||||
|
||||
const SUFFIX_START = '<suffixStart>';
|
||||
const SUFFIX_END = '<suffixEnd>';
|
||||
const PREFIX_ENDS = ['<prefixEnd>', '</prefixEnd>'];
|
||||
|
||||
const MAX_TOKENS = 50;
|
||||
const GENERATION_TIMEOUT = 200;
|
||||
|
@ -74,12 +71,20 @@ const tokenProvider = async (
|
|||
_token: vscode.CancellationToken,
|
||||
) => {
|
||||
const prefix = document.getText(new vscode.Range(0, 0, position.line, position.character));
|
||||
|
||||
|
||||
const modelSupportsSuffix = await getModelSupportsSuffix(MODEL);
|
||||
const prompt = getPrompt(document, position, prefix);
|
||||
|
||||
const suffix = modelSupportsSuffix ? getSuffix(document, position) : undefined;
|
||||
|
||||
console.log(JSON.stringify(trieRoot));
|
||||
|
||||
const result = trieRootLookup(prefix);
|
||||
|
||||
if (result !== null) {
|
||||
return flattenTrie(result);
|
||||
}
|
||||
|
||||
const response = await ollama.generate({
|
||||
model: MODEL,
|
||||
prompt,
|
||||
|
@ -88,11 +93,33 @@ const tokenProvider = async (
|
|||
stream: true,
|
||||
options: {
|
||||
num_predict: MAX_TOKENS,
|
||||
stop: [PREFIX_END]
|
||||
stop: PREFIX_ENDS,
|
||||
},
|
||||
});
|
||||
|
||||
return response;
|
||||
const resultBuffer: string[] = await new Promise(async (resolve, reject) => {
|
||||
const buffer: string[] = [];
|
||||
const timeout = setTimeout(() => {
|
||||
resolve(buffer);
|
||||
}, GENERATION_TIMEOUT);
|
||||
|
||||
try {
|
||||
for await (const part of response) {
|
||||
buffer.push(part.response);
|
||||
trieRootInsert(prefix + buffer.join(''));
|
||||
}
|
||||
resolve(buffer);
|
||||
} catch (err) {
|
||||
reject(err);
|
||||
} finally {
|
||||
response.abort();
|
||||
clearTimeout(timeout);
|
||||
};
|
||||
});
|
||||
|
||||
return [
|
||||
resultBuffer.join('')
|
||||
];
|
||||
};
|
||||
|
||||
export const activate = (context: vscode.ExtensionContext) => {
|
||||
|
@ -101,36 +128,12 @@ export const activate = (context: vscode.ExtensionContext) => {
|
|||
const provider: vscode.InlineCompletionItemProvider = {
|
||||
async provideInlineCompletionItems(document, position, context, token) {
|
||||
try {
|
||||
const response = await tokenProvider(document, position, context, token);
|
||||
const completions = await tokenProvider(document, position, context, token);
|
||||
|
||||
const resultBuffer: string[] = await new Promise(async (resolve, reject) => {
|
||||
const buffer: string[] = [];
|
||||
const timeout = setTimeout(() => {
|
||||
resolve(buffer);
|
||||
}, GENERATION_TIMEOUT);
|
||||
|
||||
try {
|
||||
for await (const part of response) {
|
||||
// process.stdout.write(part.response);
|
||||
buffer.push(part.response);
|
||||
}
|
||||
resolve(buffer);
|
||||
} catch (err) {
|
||||
reject(err);
|
||||
} finally {
|
||||
response.abort();
|
||||
clearTimeout(timeout);
|
||||
};
|
||||
});
|
||||
|
||||
const text = resultBuffer.join('');
|
||||
|
||||
return [
|
||||
{
|
||||
insertText: text,
|
||||
range: new vscode.Range(position, position),
|
||||
}
|
||||
];
|
||||
return completions.map((text) => ({
|
||||
insertText: text,
|
||||
range: new vscode.Range(position, position),
|
||||
}));
|
||||
} catch (err) {
|
||||
console.log(err);
|
||||
}
|
||||
|
@ -143,4 +146,4 @@ export const activate = (context: vscode.ExtensionContext) => {
|
|||
};
|
||||
|
||||
// This method is called when your extension is deactivated
|
||||
export function deactivate() {}
|
||||
export function deactivate() { }
|
||||
|
|
13
src/trie.ts
13
src/trie.ts
|
@ -132,3 +132,16 @@ export const trieLookup = (node: TrieNode, text: string): TrieNode | null => {
|
|||
const child = node.children[childKey];
|
||||
return child === undefined ? null : trieLookup(child, childText);
|
||||
};
|
||||
|
||||
|
||||
export const flattenTrie = (node: TrieNode): string[] => {
|
||||
if (node.isLeaf) {
|
||||
return [
|
||||
node.value
|
||||
];
|
||||
}
|
||||
|
||||
return Object.entries(node.children).flatMap(([key, child]) => {
|
||||
return flattenTrie(child).map((value) => node.value + key + value);
|
||||
});
|
||||
};
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue