added pruning to trie
This commit is contained in:
parent
f93d24dc39
commit
a620903c10
2 changed files with 48 additions and 5 deletions
|
@ -2,7 +2,7 @@
|
||||||
// Import the module and reference it with the alias vscode in your code below
|
// Import the module and reference it with the alias vscode in your code below
|
||||||
import * as vscode from 'vscode';
|
import * as vscode from 'vscode';
|
||||||
import { Ollama } from 'ollama/browser';
|
import { Ollama } from 'ollama/browser';
|
||||||
import { flattenTrie, trieInsert, trieLookup, TrieNode } from './trie';
|
import { flattenTrie, trieInsert, trieLookup, TrieNode, triePrune } from './trie';
|
||||||
|
|
||||||
const MODEL = 'deepseek-coder:6.7b';
|
const MODEL = 'deepseek-coder:6.7b';
|
||||||
|
|
||||||
|
@ -14,6 +14,7 @@ const SUFFIX_END = '</suffixStart>';
|
||||||
|
|
||||||
const MAX_TOKENS = 50;
|
const MAX_TOKENS = 50;
|
||||||
const GENERATION_TIMEOUT = 200;
|
const GENERATION_TIMEOUT = 200;
|
||||||
|
const TRIE_PRUNE_TIMEOUT = 10000;
|
||||||
|
|
||||||
const HOST = undefined;
|
const HOST = undefined;
|
||||||
|
|
||||||
|
@ -70,9 +71,8 @@ const getSuffix = (document: vscode.TextDocument, position: vscode.Position) =>
|
||||||
};
|
};
|
||||||
|
|
||||||
let trieRoot: TrieNode = {
|
let trieRoot: TrieNode = {
|
||||||
isLeaf: false,
|
isLeaf: true,
|
||||||
value: '',
|
value: '',
|
||||||
children: {},
|
|
||||||
};
|
};
|
||||||
|
|
||||||
const trieRootInsert = (text: string) => {
|
const trieRootInsert = (text: string) => {
|
||||||
|
@ -83,6 +83,10 @@ const trieRootLookup = (text: string) => {
|
||||||
return trieLookup(trieRoot, text);
|
return trieLookup(trieRoot, text);
|
||||||
};
|
};
|
||||||
|
|
||||||
|
const trieRootPrune = (text: string) => {
|
||||||
|
return triePrune(trieRoot, text);
|
||||||
|
};
|
||||||
|
|
||||||
const tokenProvider = async (
|
const tokenProvider = async (
|
||||||
document: vscode.TextDocument,
|
document: vscode.TextDocument,
|
||||||
position: vscode.Position,
|
position: vscode.Position,
|
||||||
|
@ -102,6 +106,10 @@ const tokenProvider = async (
|
||||||
return flattenTrie(result);
|
return flattenTrie(result);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (token.isCancellationRequested) {
|
||||||
|
return [];
|
||||||
|
}
|
||||||
|
|
||||||
const response = await ollama.generate({
|
const response = await ollama.generate({
|
||||||
model: MODEL,
|
model: MODEL,
|
||||||
prompt,
|
prompt,
|
||||||
|
@ -114,8 +122,15 @@ const tokenProvider = async (
|
||||||
},
|
},
|
||||||
});
|
});
|
||||||
|
|
||||||
|
const pruneTimeout = setTimeout(() => {
|
||||||
|
trieRootPrune(prompt);
|
||||||
|
}, TRIE_PRUNE_TIMEOUT);
|
||||||
|
|
||||||
token.onCancellationRequested(() => {
|
token.onCancellationRequested(() => {
|
||||||
response.abort();
|
clearTimeout(pruneTimeout);
|
||||||
|
try {
|
||||||
|
response.abort();
|
||||||
|
} catch {}
|
||||||
});
|
});
|
||||||
|
|
||||||
const resultBuffer: string[] = await new Promise(async (resolve, reject) => {
|
const resultBuffer: string[] = await new Promise(async (resolve, reject) => {
|
||||||
|
|
30
src/trie.ts
30
src/trie.ts
|
@ -96,6 +96,12 @@ export const trieInsert = (node: TrieNode, text: string): TrieNode => {
|
||||||
return node;
|
return node;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Gets a new trie that is the node trie after the text
|
||||||
|
* @param node
|
||||||
|
* @param text
|
||||||
|
* @returns
|
||||||
|
*/
|
||||||
export const trieLookup = (node: TrieNode, text: string): TrieNode | null => {
|
export const trieLookup = (node: TrieNode, text: string): TrieNode | null => {
|
||||||
for (let index = 0; index < node.value.length; index++) {
|
for (let index = 0; index < node.value.length; index++) {
|
||||||
// If our node still has text left but we have no more search query return a node starting at where we ran out of characters
|
// If our node still has text left but we have no more search query return a node starting at where we ran out of characters
|
||||||
|
@ -133,7 +139,6 @@ export const trieLookup = (node: TrieNode, text: string): TrieNode | null => {
|
||||||
return child === undefined ? null : trieLookup(child, childText);
|
return child === undefined ? null : trieLookup(child, childText);
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
||||||
export const flattenTrie = (node: TrieNode): string[] => {
|
export const flattenTrie = (node: TrieNode): string[] => {
|
||||||
if (node.isLeaf) {
|
if (node.isLeaf) {
|
||||||
return [
|
return [
|
||||||
|
@ -145,3 +150,26 @@ export const flattenTrie = (node: TrieNode): string[] => {
|
||||||
return flattenTrie(child).map((value) => node.value + key + value);
|
return flattenTrie(child).map((value) => node.value + key + value);
|
||||||
});
|
});
|
||||||
};
|
};
|
||||||
|
|
||||||
|
export const triePrune = (node: TrieNode, text: string): TrieNode => {
|
||||||
|
const value = trieLookup(node, text);
|
||||||
|
if (value === null) {
|
||||||
|
return {
|
||||||
|
isLeaf: true,
|
||||||
|
value: text,
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
if (value.isLeaf) {
|
||||||
|
return {
|
||||||
|
isLeaf: true,
|
||||||
|
value: text + value.value,
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
return {
|
||||||
|
isLeaf: false,
|
||||||
|
value: text + value.value,
|
||||||
|
children: value.children,
|
||||||
|
};
|
||||||
|
};
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue