diff --git a/package.json b/package.json index 651ba27..ab8165a 100644 --- a/package.json +++ b/package.json @@ -15,15 +15,8 @@ "main": "./dist/extension.js", "contributes": { "commands": [ - { - "command": "ai-code.helloWorld", - "title": "Hello World" - } ] }, - "enabledApiProposals": [ - "inlineCompletionsAdditions" - ], "scripts": { "vscode:prepublish": "npm run package", "compile": "webpack", diff --git a/src/autoComplete.ts b/src/autoComplete.ts new file mode 100644 index 0000000..7c77e64 --- /dev/null +++ b/src/autoComplete.ts @@ -0,0 +1,162 @@ +import { ExtensionState } from "./config"; +import * as vscode from 'vscode'; +import { flattenTrie, trieInsert, trieLookup, TrieNode, triePrune } from "./trie"; + +const getModelSupportsSuffix = async (extension: ExtensionState, model: string) => { + // TODO: get if model supports suffixes and use that if available + + // const response = await ollama.show({ + // model: model + // }) + + // model.capabilities.includes('suffix') + return false; +}; + +const getPrompt = (extension: ExtensionState, document: vscode.TextDocument, position: vscode.Position, prefix: string) => { + const messageHeader = `In an english code base with the file.\nfile:\nproject {PROJECT_NAME}\nfile {FILE_NAME}\nlanguage {LANG}` + .replace("{PROJECT_NAME}", vscode.workspace.name || "Untitled") + .replace("{FILE_NAME}", document.fileName) + .replace("{LANG}", document.languageId); + + const message = `File:\n${extension.configuration.inlineCompletion.prefixStart}`; + + + const prompt = `${messageHeader}\n${message}\n${prefix}`; + + return prompt; +}; + +const getPromptWithSuffix = (extension: ExtensionState, document: vscode.TextDocument, position: vscode.Position, prefix: string) => { + const suffix = document.getText(new vscode.Range(position.line, position.character, document.lineCount - 1, document.lineAt(document.lineCount - 1).text.length)); + + const messageSuffix = `End of the file:\n${extension.configuration.inlineCompletion.suffixStart}\n${suffix}\n${extension.configuration.inlineCompletion.suffixEnd}\n`; + const messagePrefix = `Start of the file:\n${extension.configuration.inlineCompletion.prefixStart}`; + + const messageHeader = `In an english code base with the file.\nfile:\nproject {PROJECT_NAME}\nfile {FILE_NAME}\nlanguage {LANG}\n.` + .replace("{PROJECT_NAME}", vscode.workspace.name || "Untitled") + .replace("{FILE_NAME}", document.fileName) + .replace("{LANG}", document.languageId); + + const prompt = `${messageHeader}\n${messageSuffix}\n${messagePrefix}\n${prefix}`; + + return prompt; +}; + +const getSuffix = (extension: ExtensionState, document: vscode.TextDocument, position: vscode.Position) => { + const suffix = document.getText(new vscode.Range(position.line, position.character, document.lineCount - 1, document.lineAt(document.lineCount - 1).text.length)); + + return suffix; +}; + +let trieRoot: TrieNode = { + isLeaf: true, + value: '', +}; + +const trieRootInsert = (text: string) => { + trieRoot = trieInsert(trieRoot, text); +}; + +const trieRootLookup = (text: string) => { + return trieLookup(trieRoot, text); +}; + +const trieRootPrune = (text: string) => { + return triePrune(trieRoot, text); +}; + +const tokenProvider = async ( + extension: ExtensionState, + document: vscode.TextDocument, + position: vscode.Position, + _context: vscode.InlineCompletionContext, + token: vscode.CancellationToken, +) => { + const prefix = document.getText(new vscode.Range(0, 0, position.line, position.character)); + + const model = extension.configuration.inlineCompletion.model; + + const modelSupportsSuffix = await getModelSupportsSuffix(extension, model); + const prompt = modelSupportsSuffix ? getPrompt(extension, document, position, prefix) : getPromptWithSuffix(extension, document, position, prefix); + + const suffix = modelSupportsSuffix ? getSuffix(extension, document, position) : undefined; + + const result = trieRootLookup(prefix); + + if (result !== null) { + return flattenTrie(result); + } + + if (token.isCancellationRequested) { + return []; + } + + const response = await extension.ollama.generate({ + model, + prompt, + suffix, + raw: true, + stream: true, + options: { + num_predict: extension.configuration.inlineCompletion.maxTokens, + stop: extension.configuration.inlineCompletion.prefixEnds, + }, + }); + + const pruneTimeout = setTimeout(() => { + trieRootPrune(prompt); + }, extension.configuration.inlineCompletion.triePruneTimeout); + + token.onCancellationRequested(() => { + clearTimeout(pruneTimeout); + try { + response.abort(); + } catch { } + }); + + const resultBuffer: string[] = await new Promise(async (resolve, reject) => { + const buffer: string[] = []; + const timeout = setTimeout(() => { + resolve(buffer); + }, extension.configuration.inlineCompletion.generationTimeout); + + try { + for await (const part of response) { + buffer.push(part.response); + trieRootInsert(prefix + buffer.join('')); + } + resolve(buffer); + } catch (err) { + reject(err); + } finally { + response.abort(); + clearTimeout(timeout); + }; + }); + + return [ + resultBuffer.join('') + ]; +}; + +export const getAutoCompleteProvider = (extension: ExtensionState) => { + const provider: vscode.InlineCompletionItemProvider = { + async provideInlineCompletionItems(document, position, context, token) { + try { + const completions = await tokenProvider(extension, document, position, context, token); + + return completions.map((text) => ({ + insertText: text, + range: new vscode.Range(position, position), + })); + } catch (err) { + console.log(err); + } + + return []; + }, + }; + + return provider; +}; \ No newline at end of file diff --git a/src/config.ts b/src/config.ts new file mode 100644 index 0000000..4da9580 --- /dev/null +++ b/src/config.ts @@ -0,0 +1,112 @@ +import { Ollama } from 'ollama/browser'; +import * as vscode from 'vscode'; + +const CONFIG_NAMESPACE = 'ai-code'; + +const KEY_OLLAMA_HOST = 'ollamaHost'; + +const KEY_INLINE_COMPLETION_MODEL = 'inlineCompletion.model'; +const KEY_INLINE_COMPLETION_PREFIX_START = 'inlineCompletion.prefixStart'; +const KEY_INLINE_COMPLETION_PREFIX_END = 'inlineCompletion.prefixEnd'; +const KEY_INLINE_COMPLETION_SUFFIX_START = 'inlineCompletion.suffixStart'; +const KEY_INLINE_COMPLETION_SUFFIX_END = 'inlineCompletion.suffixEnd'; +const KEY_INLINE_COMPLETION_MAX_TOKENS = 'inlineCompletion.maxTokens'; +const KEY_INLINE_COMPLETION_GENERATION_TIMEOUT = 'inlineCompletion.generationTimeout'; +const KEY_INLINE_COMPLETION_TRIE_PRUNE_TIMEOUT = 'inlineCompletion.triePruneTimeout'; + +const DEFAULT_INLINE_COMPLETION_MODEL = 'deepseek-coder:6.7b'; +const DEFAULT_INLINE_COMPLETION_PREFIX_START = ''; +const DEFAULT_INLINE_COMPLETION_PREFIX_ENDS = ['', '', '', '']; +const DEFAULT_INLINE_COMPLETION_SUFFIX_START = ''; +const DEFAULT_INLINE_COMPLETION_SUFFIX_END = ''; +const DEFAULT_INLINE_COMPLETION_MAX_TOKENS = 50; +const DEFAULT_INLINE_COMPLETION_GENERATION_TIMEOUT = 200; +const DEFAULT_INLINE_COMPLETION_TRIE_PRUNE_TIMEOUT = 10000; + +interface ExtensionConfiguration { + ollamaHost: string | undefined + inlineCompletion: { + model: string + prefixStart: string + prefixEnds: string[] + suffixStart: string + suffixEnd: string + maxTokens: number + generationTimeout: number + triePruneTimeout: number + } +}; + +export interface ExtensionState { + configuration: ExtensionConfiguration + ollama: Ollama +} + +export const getExtensionState = (): ExtensionState => { + const extensionConfig = vscode.workspace.getConfiguration(CONFIG_NAMESPACE); + + const configuration: ExtensionConfiguration = { + ollamaHost: extensionConfig.get(KEY_OLLAMA_HOST), + inlineCompletion: { + model: extensionConfig.get(KEY_INLINE_COMPLETION_MODEL) ?? DEFAULT_INLINE_COMPLETION_MODEL, + prefixStart: extensionConfig.get(KEY_INLINE_COMPLETION_PREFIX_START) ?? DEFAULT_INLINE_COMPLETION_PREFIX_START, + prefixEnds: extensionConfig.get(KEY_INLINE_COMPLETION_PREFIX_END)?.split(',') ?? DEFAULT_INLINE_COMPLETION_PREFIX_ENDS, + suffixStart: extensionConfig.get(KEY_INLINE_COMPLETION_SUFFIX_START) ?? DEFAULT_INLINE_COMPLETION_SUFFIX_START, + suffixEnd: extensionConfig.get(KEY_INLINE_COMPLETION_SUFFIX_END) ?? DEFAULT_INLINE_COMPLETION_SUFFIX_END, + maxTokens: extensionConfig.get(KEY_INLINE_COMPLETION_MAX_TOKENS) ?? DEFAULT_INLINE_COMPLETION_MAX_TOKENS, + generationTimeout: extensionConfig.get(KEY_INLINE_COMPLETION_GENERATION_TIMEOUT) ?? DEFAULT_INLINE_COMPLETION_GENERATION_TIMEOUT, + triePruneTimeout: extensionConfig.get(KEY_INLINE_COMPLETION_TRIE_PRUNE_TIMEOUT) ?? DEFAULT_INLINE_COMPLETION_TRIE_PRUNE_TIMEOUT, + }, + }; + + const state: ExtensionState = { + ollama: new Ollama({ + host: configuration.ollamaHost, + }), + configuration, + }; + + + vscode.workspace.onDidChangeConfiguration((event) => { + if (event.affectsConfiguration(`${CONFIG_NAMESPACE}.${KEY_OLLAMA_HOST}`)) { + configuration.ollamaHost = extensionConfig.get(KEY_OLLAMA_HOST); + state.ollama = new Ollama({ + host: configuration.ollamaHost, + }); + } + + if (event.affectsConfiguration(`${CONFIG_NAMESPACE}.${KEY_INLINE_COMPLETION_MODEL}`)) { + configuration.inlineCompletion.model = extensionConfig.get(KEY_INLINE_COMPLETION_MODEL) ?? DEFAULT_INLINE_COMPLETION_MODEL; + } + + if (event.affectsConfiguration(`${CONFIG_NAMESPACE}.${KEY_INLINE_COMPLETION_PREFIX_START}`)) { + configuration.inlineCompletion.prefixStart = extensionConfig.get(KEY_INLINE_COMPLETION_PREFIX_START) ?? DEFAULT_INLINE_COMPLETION_PREFIX_START; + } + + if (event.affectsConfiguration(`${CONFIG_NAMESPACE}.${KEY_INLINE_COMPLETION_PREFIX_END}`)) { + configuration.inlineCompletion.prefixEnds = extensionConfig.get(KEY_INLINE_COMPLETION_PREFIX_END)?.split(',') ?? DEFAULT_INLINE_COMPLETION_PREFIX_ENDS; + } + + if (event.affectsConfiguration(`${CONFIG_NAMESPACE}.${KEY_INLINE_COMPLETION_SUFFIX_START}`)) { + configuration.inlineCompletion.suffixStart = extensionConfig.get(KEY_INLINE_COMPLETION_SUFFIX_START) ?? DEFAULT_INLINE_COMPLETION_SUFFIX_START; + } + + if (event.affectsConfiguration(`${CONFIG_NAMESPACE}.${KEY_INLINE_COMPLETION_SUFFIX_END}`)) { + configuration.inlineCompletion.suffixEnd = extensionConfig.get(KEY_INLINE_COMPLETION_SUFFIX_END) ?? DEFAULT_INLINE_COMPLETION_SUFFIX_END; + } + + if (event.affectsConfiguration(`${CONFIG_NAMESPACE}.${KEY_INLINE_COMPLETION_MAX_TOKENS}`)) { + configuration.inlineCompletion.maxTokens = extensionConfig.get(KEY_INLINE_COMPLETION_MAX_TOKENS) ?? DEFAULT_INLINE_COMPLETION_MAX_TOKENS; + } + + if (event.affectsConfiguration(`${CONFIG_NAMESPACE}.${KEY_INLINE_COMPLETION_GENERATION_TIMEOUT}`)) { + configuration.inlineCompletion.generationTimeout = extensionConfig.get(KEY_INLINE_COMPLETION_GENERATION_TIMEOUT) ?? DEFAULT_INLINE_COMPLETION_GENERATION_TIMEOUT; + } + + if (event.affectsConfiguration(`${CONFIG_NAMESPACE}.${KEY_INLINE_COMPLETION_TRIE_PRUNE_TIMEOUT}`)) { + configuration.inlineCompletion.triePruneTimeout = extensionConfig.get(KEY_INLINE_COMPLETION_TRIE_PRUNE_TIMEOUT) ?? DEFAULT_INLINE_COMPLETION_TRIE_PRUNE_TIMEOUT; + } + }); + + return state; +}; diff --git a/src/extension.ts b/src/extension.ts index 90ddc82..bbba054 100644 --- a/src/extension.ts +++ b/src/extension.ts @@ -1,146 +1,26 @@ // The module 'vscode' contains the VS Code extensibility API // Import the module and reference it with the alias vscode in your code below import * as vscode from 'vscode'; -import { Ollama } from 'ollama/browser'; - -const MODEL = 'deepseek-coder:6.7b'; - -const PREFIX_START = ''; -const PREFIX_END = ''; - -const SUFFIX_START = ''; -const SUFFIX_END = ''; - -const MAX_TOKENS = 50; -const GENERATION_TIMEOUT = 200; - -const HOST = undefined; - -// TODO: make these configurable by extension setting -const ollama = new Ollama({ - host: HOST, -}); - -const getModelSupportsSuffix = async (model: string) => { - // TODO: get if model supports suffixes and use that if available - - // const response = await ollama.show({ - // model: model - // }) - - // model.capabilities.includes('suffix') - return false; -}; - -const getPrompt = (document: vscode.TextDocument, position: vscode.Position) => { - const prefix = document.getText(new vscode.Range(0, 0, position.line, position.character)); - - const messageHeader = `In an english code base with the file.\nfile:\nproject {PROJECT_NAME}\nfile {FILE_NAME}\nlanguage {LANG}` - .replace("{PROJECT_NAME}", vscode.workspace.name || "Untitled") - .replace("{FILE_NAME}", document.fileName) - .replace("{LANG}", document.languageId); - - const message = `File:\n${PREFIX_START}`; - - - const prompt = `${messageHeader}\n${message}\n${prefix}`; - - return prompt; -}; - -const getPromptWithSuffix = (document: vscode.TextDocument, position: vscode.Position) => { - const prefix = document.getText(new vscode.Range(0, 0, position.line, position.character)); - const suffix = document.getText(new vscode.Range(position.line, position.character, document.lineCount - 1, document.lineAt(document.lineCount - 1).text.length)); - - const messageSuffix = `End of the file:\n${SUFFIX_START}\n${suffix}\n${SUFFIX_END}\n`; - const messagePrefix = `Start of the file:\n${PREFIX_START}`; - - const messageHeader = `In an english code base with the file.\nfile:\nproject {PROJECT_NAME}\nfile {FILE_NAME}\nlanguage {LANG}\n.` - .replace("{PROJECT_NAME}", vscode.workspace.name || "Untitled") - .replace("{FILE_NAME}", document.fileName) - .replace("{LANG}", document.languageId); - - const prompt = `${messageHeader}\n${messageSuffix}\n${messagePrefix}\n${prefix}`; - - return prompt; -}; - -const getSuffix = (document: vscode.TextDocument, position: vscode.Position) => { - const suffix = document.getText(new vscode.Range(position.line, position.character, document.lineCount - 1, document.lineAt(document.lineCount - 1).text.length)); - - return suffix; -}; - -const tokenProvider = async ( - document: vscode.TextDocument, - position: vscode.Position, - context: vscode.InlineCompletionContext, - _token: vscode.CancellationToken, -) => { - const modelSupportsSuffix = await getModelSupportsSuffix(MODEL); - const prompt = modelSupportsSuffix ? getPrompt(document, position) : getPromptWithSuffix(document, position); - const suffix = modelSupportsSuffix ? getSuffix(document, position) : undefined; - - const response = await ollama.generate({ - model: MODEL, - prompt, - suffix, - raw: true, - stream: true, - options: { - num_predict: MAX_TOKENS, - stop: [PREFIX_END] - }, - }); - - return response; -}; +import { getExtensionState } from './config'; +import { getAutoCompleteProvider } from './autoComplete'; export const activate = (context: vscode.ExtensionContext) => { console.log('"ai-code" extensions loaded'); - const provider: vscode.InlineCompletionItemProvider = { - async provideInlineCompletionItems(document, position, context, token) { - try { - const response = await tokenProvider(document, position, context, token); + const extension = getExtensionState(); - const resultBuffer: string[] = await new Promise(async (resolve, reject) => { - const buffer: string[] = []; - const timeout = setTimeout(() => { - resolve(buffer); - }, GENERATION_TIMEOUT); + const autoCompleteProvider = getAutoCompleteProvider(extension); - try { - for await (const part of response) { - console.log(part.response); - buffer.push(part.response); - } - resolve(buffer); - } catch (err) { - reject(err); - } finally { - clearTimeout(timeout); - }; - }); + // TODO: code suggestion provider - const text = resultBuffer.join(''); + // TODO: quick fix provider - return [ - { - insertText: text, - range: new vscode.Range(position, position), - } - ]; - } catch (err) { - console.log(err); - } + // TODO: chat provider - return []; - }, - }; + // TODO: agent mode provider - vscode.languages.registerInlineCompletionItemProvider({ pattern: '**' }, provider); + vscode.languages.registerInlineCompletionItemProvider({ pattern: '**' }, autoCompleteProvider); }; // This method is called when your extension is deactivated -export function deactivate() {} +export function deactivate() { } diff --git a/src/test/extension.test.ts b/src/test/extension.test.ts deleted file mode 100644 index 4ca0ab4..0000000 --- a/src/test/extension.test.ts +++ /dev/null @@ -1,15 +0,0 @@ -import * as assert from 'assert'; - -// You can import and use all API from the 'vscode' module -// as well as import your extension to test it -import * as vscode from 'vscode'; -// import * as myExtension from '../../extension'; - -suite('Extension Test Suite', () => { - vscode.window.showInformationMessage('Start all tests.'); - - test('Sample test', () => { - assert.strictEqual(-1, [1, 2, 3].indexOf(5)); - assert.strictEqual(-1, [1, 2, 3].indexOf(0)); - }); -}); diff --git a/src/trie.ts b/src/trie.ts new file mode 100644 index 0000000..3ed2ed9 --- /dev/null +++ b/src/trie.ts @@ -0,0 +1,175 @@ + +interface TrieLeaf { + isLeaf: true + value: string + children?: never +} + +interface TrieBranch { + isLeaf: false + value: string + children: { [key: string]: TrieNode } +} + +export type TrieNode = TrieLeaf | TrieBranch + +/** + * Creates a new TrieNode based on node that has text added to it + * @param node node that we are basing the update on + * @param text text that is being added to the node + * @returns a new node with text added to it + */ +export const trieInsert = (node: TrieNode, text: string): TrieNode => { + // TODO: mutate node to add text to it + for (let index = 0; index < text.length; index++) { + // If the inserted text is longer then the nodes text update the node with the new text + if (index >= node.value.length) { + // If the current node is a leaf we can just replace it with a larger leaf + if (node.isLeaf) { + const newLeaf: TrieLeaf = { + isLeaf: true, + value: text, + }; + return newLeaf; + } + + // If the current node is a branch then we need add the remaining text to one of its children + const childKey = text[index]; + const childText = text.substring(index + 1); + const child = node.children[childKey]; + + const newBranch: TrieBranch = { + isLeaf: false, + value: node.value, + children: { + ...node.children, + [childKey]: child === undefined ? { + isLeaf: true, + value: childText, + } : trieInsert(child, childText) + }, + }; + return newBranch; + } + + // If our inserted text does not match the node then we need to split the node + if (node.value[index] !== text[index]) { + // If the node is a leaf we need to split it into a branch + if (node.isLeaf) { + const newBranch: TrieBranch = { + isLeaf: false, + value: text.substring(0, index), + children: { + [text[index]]: { + isLeaf: true, + value: text.substring(index + 1), + }, + [node.value[index]]: { + isLeaf: true, + value: node.value.substring(index + 1), + }, + }, + }; + return newBranch; + } + + // If the node is a branch then we need to create a new leaf on it + const newBranch: TrieBranch = { + isLeaf: false, + value: text.substring(0, index), + children: { + [text[index]]: { + isLeaf: true, + value: text.substring(index + 1), + }, + [node.value[index]]: { + isLeaf: false, + value: node.value.substring(index + 1), + children: node.children, + }, + }, + }; + return newBranch; + } + } + + return node; +}; + +/** + * Gets a new trie that is the node trie after the text + * @param node + * @param text + * @returns + */ +export const trieLookup = (node: TrieNode, text: string): TrieNode | null => { + for (let index = 0; index < node.value.length; index++) { + // If our node still has text left but we have no more search query return a node starting at where we ran out of characters + if (index >= text.length) { + if (node.isLeaf) { + const newNode: TrieLeaf = { + isLeaf: true, + value: node.value.substring(index), + }; + return newNode; + } + const newNode: TrieBranch = { + isLeaf: false, + value: node.value.substring(index), + children: node.children, + }; + return newNode; + } + + // If we have a difference then there is no match + if (node.value[index] !== text[index]) { + return null; + } + } + + // If we get past the end of the node and it is a leaf then there is no match + if (node.isLeaf) { + return null; + } + + // Continue matching on the child node + const childKey = text[node.value.length]; + const childText = text.substring(node.value.length + 1); + const child = node.children[childKey]; + return child === undefined ? null : trieLookup(child, childText); +}; + +export const flattenTrie = (node: TrieNode): string[] => { + if (node.isLeaf) { + return [ + node.value + ]; + } + + return Object.entries(node.children).flatMap(([key, child]) => { + return flattenTrie(child).map((value) => node.value + key + value); + }); +}; + +export const triePrune = (node: TrieNode, text: string): TrieNode => { + const value = trieLookup(node, text); + if (value === null) { + return { + isLeaf: true, + value: text, + }; + } + + if (value.isLeaf) { + return { + isLeaf: true, + value: text + value.value, + }; + } + + return { + isLeaf: false, + value: text + value.value, + children: value.children, + }; +};