From e92df793bc0146c60973b08e965e54af9ab369cd Mon Sep 17 00:00:00 2001 From: Phil Sautter <20444474+redeux@users.noreply.github.com> Date: Thu, 9 Feb 2023 10:10:13 -0500 Subject: [PATCH 1/5] Fix encoder --- dist/index.js | 54 ------- .../index.js | 2 +- dist/utils/{denylist => deny_list}/index.js | 0 dist/utils/index.js | 10 -- dist/utils/language_detection/index.js | 41 ++++++ package-lock.json | 135 ++++++++++++++++++ package.json | 1 + src/index.ts | 19 +-- src/utils/encoder/{index.js => index.ts} | 48 +++---- src/utils/index.ts | 15 +- tsconfig.json | 2 +- 11 files changed, 221 insertions(+), 106 deletions(-) delete mode 100644 dist/index.js rename dist/utils/{attackmitigation => attack_mitigation}/index.js (98%) rename dist/utils/{denylist => deny_list}/index.js (100%) delete mode 100644 dist/utils/index.js create mode 100644 dist/utils/language_detection/index.js rename src/utils/encoder/{index.js => index.ts} (86%) diff --git a/dist/index.js b/dist/index.js deleted file mode 100644 index 8085d44..0000000 --- a/dist/index.js +++ /dev/null @@ -1,54 +0,0 @@ -#!/usr/bin/env ts-node -"use strict"; -var __awaiter = (this && this.__awaiter) || function (thisArg, _arguments, P, generator) { - function adopt(value) { return value instanceof P ? value : new P(function (resolve) { resolve(value); }); } - return new (P || (P = Promise))(function (resolve, reject) { - function fulfilled(value) { try { step(generator.next(value)); } catch (e) { reject(e); } } - function rejected(value) { try { step(generator["throw"](value)); } catch (e) { reject(e); } } - function step(result) { result.done ? resolve(result.value) : adopt(result.value).then(fulfilled, rejected); } - step((generator = generator.apply(thisArg, _arguments || [])).next()); - }); -}; -Object.defineProperty(exports, "__esModule", { value: true }); -exports.PromptGuard = void 0; -const utils_1 = require("./utils"); -var FAILURE_REASON; -(function (FAILURE_REASON) { - FAILURE_REASON["DENY_LIST"] = "CONTAINS_DENY_LIST_ITEM"; - FAILURE_REASON["MAX_TOKEN_THRESHOLD"] = "EXCEEDS_MAX_TOKEN_THRESHOLD"; - FAILURE_REASON["KNOWN_ATTACK"] = "CONTAINS_KNOWN_ATTACK"; -})(FAILURE_REASON || (FAILURE_REASON = {})); -class PromptGuard { - constructor(userPolicyOptions = {}) { - const defaultPromptGuardPolicy = { - maxTokens: 4096, - denyList: [''], - disableAttackMitigation: false, - encodeOutput: false - }; - // merge the user policy with the default policy to create the policy - this.promptGuardPolicy = Object.assign(Object.assign({}, defaultPromptGuardPolicy), userPolicyOptions); - } - process(prompt) { - return __awaiter(this, void 0, void 0, function* () { - // processing order - // normalize -> quote -> escape -> check tokens -> check cache -> check for known attacks -> check allow list -> check deny list -> encode output - // check the prompt token count - if ((0, utils_1.countPromptTokens)(prompt) > this.promptGuardPolicy.maxTokens) - return { pass: false, output: FAILURE_REASON.MAX_TOKEN_THRESHOLD }; - // check prompt against known prompt attacks - if (!this.promptGuardPolicy.disableAttackMitigation) { - if (yield (0, utils_1.promptContainsKnownAttack)(prompt)) - return { pass: false, output: FAILURE_REASON.KNOWN_ATTACK }; - } - // check prompt again the user defined deny list - if (yield (0, utils_1.promptContainsDenyListItems)(prompt, this.promptGuardPolicy.denyList)) - return { pass: false, output: FAILURE_REASON.DENY_LIST }; - // encode the prompt output if encodeOutput is set by the user - if (this.promptGuardPolicy.encodeOutput) - prompt = (0, utils_1.encodePromptOutput)(prompt); - return { pass: true, output: prompt }; - }); - } -} -exports.PromptGuard = PromptGuard; diff --git a/dist/utils/attackmitigation/index.js b/dist/utils/attack_mitigation/index.js similarity index 98% rename from dist/utils/attackmitigation/index.js rename to dist/utils/attack_mitigation/index.js index 5bb722e..24ec570 100644 --- a/dist/utils/attackmitigation/index.js +++ b/dist/utils/attack_mitigation/index.js @@ -12,7 +12,7 @@ Object.defineProperty(exports, "__esModule", { value: true }); exports.containsKnownAttack = void 0; const attackDenyList = [ "ignore above", - "gnore the above", + "ignore the above", "ignore previous instructions", "ignore the previous instructions", "ignore above instructions", diff --git a/dist/utils/denylist/index.js b/dist/utils/deny_list/index.js similarity index 100% rename from dist/utils/denylist/index.js rename to dist/utils/deny_list/index.js diff --git a/dist/utils/index.js b/dist/utils/index.js deleted file mode 100644 index d67766a..0000000 --- a/dist/utils/index.js +++ /dev/null @@ -1,10 +0,0 @@ -"use strict"; -Object.defineProperty(exports, "__esModule", { value: true }); -exports.encodePromptOutput = exports.countPromptTokens = exports.promptContainsKnownAttack = exports.promptContainsDenyListItems = void 0; -const denylist_1 = require("./denylist"); -exports.promptContainsDenyListItems = denylist_1.containsDenyListItems; -const attackmitigation_1 = require("./attackmitigation"); -exports.promptContainsKnownAttack = attackmitigation_1.containsKnownAttack; -const encoder = require("./encoder"); -exports.countPromptTokens = encoder.countTokens; -exports.encodePromptOutput = encoder.encode; diff --git a/dist/utils/language_detection/index.js b/dist/utils/language_detection/index.js new file mode 100644 index 0000000..31f03b1 --- /dev/null +++ b/dist/utils/language_detection/index.js @@ -0,0 +1,41 @@ +"use strict"; +var __awaiter = (this && this.__awaiter) || function (thisArg, _arguments, P, generator) { + function adopt(value) { return value instanceof P ? value : new P(function (resolve) { resolve(value); }); } + return new (P || (P = Promise))(function (resolve, reject) { + function fulfilled(value) { try { step(generator.next(value)); } catch (e) { reject(e); } } + function rejected(value) { try { step(generator["throw"](value)); } catch (e) { reject(e); } } + function step(result) { result.done ? resolve(result.value) : adopt(result.value).then(fulfilled, rejected); } + step((generator = generator.apply(thisArg, _arguments || [])).next()); + }); +}; +var __importDefault = (this && this.__importDefault) || function (mod) { + return (mod && mod.__esModule) ? mod : { "default": mod }; +}; +Object.defineProperty(exports, "__esModule", { value: true }); +exports.containsLanguages = void 0; +const lande_1 = __importDefault(require("lande")); +function containsLanguages(prompt, languages) { + return __awaiter(this, void 0, void 0, function* () { + const detectedLanguages = []; + // lande returns a sorted list of detected languages and their probabilities. + // for now, we're selecting all languages with a probability greater than 80% + // this may need to be tuned later + const landeOuput = (0, lande_1.default)(prompt); + for (const lang of landeOuput) { + if (lang[1] > 0.8) + detectedLanguages.push(lang[0]); + else + break; + } + for (const lang of detectedLanguages) { + if (languages.includes(lang)) + return true; + } + return false; + }); +} +exports.containsLanguages = containsLanguages; +// export async function validateLanguageList(list: string[]): Promise { +// //foo +// return true; +// } diff --git a/package-lock.json b/package-lock.json index ab28a44..8772e2a 100644 --- a/package-lock.json +++ b/package-lock.json @@ -15,6 +15,7 @@ "eslint": "^8.33.0", "jest": "^29.4.1", "ts-jest": "^29.0.5", + "ts-node": "^10.9.1", "typescript": "^4.9.5" } }, @@ -645,6 +646,28 @@ "integrity": "sha512-0hYQ8SB4Db5zvZB4axdMHGwEaQjkZzFjQiN9LVYvIFB2nSUHW9tYpxWriPrWDASIxiaXax83REcLxuSdnGPZtw==", "dev": true }, + "node_modules/@cspotcode/source-map-support": { + "version": "0.8.1", + "resolved": "https://registry.npmjs.org/@cspotcode/source-map-support/-/source-map-support-0.8.1.tgz", + "integrity": "sha512-IchNf6dN4tHoMFIn/7OE8LWZ19Y6q/67Bmf6vnGREv8RSbBVb9LPJxEcnwrcwX6ixSvaiGoomAUvu4YSxXrVgw==", + "dev": true, + "dependencies": { + "@jridgewell/trace-mapping": "0.3.9" + }, + "engines": { + "node": ">=12" + } + }, + "node_modules/@cspotcode/source-map-support/node_modules/@jridgewell/trace-mapping": { + "version": "0.3.9", + "resolved": "https://registry.npmjs.org/@jridgewell/trace-mapping/-/trace-mapping-0.3.9.tgz", + "integrity": "sha512-3Belt6tdc8bPgAtbcmdtNJlirVoTmEb5e2gC94PnkwEW9jI6CAHUeoG85tjWP5WquqfavoMtMwiG4P926ZKKuQ==", + "dev": true, + "dependencies": { + "@jridgewell/resolve-uri": "^3.0.3", + "@jridgewell/sourcemap-codec": "^1.4.10" + } + }, "node_modules/@eslint/eslintrc": { "version": "1.4.1", "resolved": "https://registry.npmjs.org/@eslint/eslintrc/-/eslintrc-1.4.1.tgz", @@ -1193,6 +1216,30 @@ "@sinonjs/commons": "^2.0.0" } }, + "node_modules/@tsconfig/node10": { + "version": "1.0.9", + "resolved": "https://registry.npmjs.org/@tsconfig/node10/-/node10-1.0.9.tgz", + "integrity": "sha512-jNsYVVxU8v5g43Erja32laIDHXeoNvFEpX33OK4d6hljo3jDhCBDhx5dhCCTMWUojscpAagGiRkBKxpdl9fxqA==", + "dev": true + }, + "node_modules/@tsconfig/node12": { + "version": "1.0.11", + "resolved": "https://registry.npmjs.org/@tsconfig/node12/-/node12-1.0.11.tgz", + "integrity": "sha512-cqefuRsh12pWyGsIoBKJA9luFu3mRxCA+ORZvA4ktLSzIuCUtWVxGIuXigEwO5/ywWFMZ2QEGKWvkZG1zDMTag==", + "dev": true + }, + "node_modules/@tsconfig/node14": { + "version": "1.0.3", + "resolved": "https://registry.npmjs.org/@tsconfig/node14/-/node14-1.0.3.tgz", + "integrity": "sha512-ysT8mhdixWK6Hw3i1V2AeRqZ5WfXg1G43mqoYlM2nc6388Fq5jcXyr5mRsqViLx/GJYdoL0bfXD8nmF+Zn/Iow==", + "dev": true + }, + "node_modules/@tsconfig/node16": { + "version": "1.0.3", + "resolved": "https://registry.npmjs.org/@tsconfig/node16/-/node16-1.0.3.tgz", + "integrity": "sha512-yOlFc+7UtL/89t2ZhjPvvB/DeAr3r+Dq58IgzsFkOAvVC6NMJXmCGjbptdXdR9qsX7pKcTL+s87FtYREi2dEEQ==", + "dev": true + }, "node_modules/@types/babel__core": { "version": "7.20.0", "resolved": "https://registry.npmjs.org/@types/babel__core/-/babel__core-7.20.0.tgz", @@ -1521,6 +1568,15 @@ "acorn": "^6.0.0 || ^7.0.0 || ^8.0.0" } }, + "node_modules/acorn-walk": { + "version": "8.2.0", + "resolved": "https://registry.npmjs.org/acorn-walk/-/acorn-walk-8.2.0.tgz", + "integrity": "sha512-k+iyHEuPgSw6SbuDpGQM+06HQUa04DZ3o+F6CSzXMvvI5KMvnaEqXe+YVe555R9nn6GPt404fos4wcgpw12SDA==", + "dev": true, + "engines": { + "node": ">=0.4.0" + } + }, "node_modules/ajv": { "version": "6.12.6", "resolved": "https://registry.npmjs.org/ajv/-/ajv-6.12.6.tgz", @@ -1601,6 +1657,12 @@ "node": ">= 8" } }, + "node_modules/arg": { + "version": "4.1.3", + "resolved": "https://registry.npmjs.org/arg/-/arg-4.1.3.tgz", + "integrity": "sha512-58S9QDqG0Xx27YwPSt9fJxivjYl432YCwfDMfZ+71RAqUrZef7LrKQZ3LHLOwCS4FLNBplP533Zx895SeOCHvA==", + "dev": true + }, "node_modules/argparse": { "version": "2.0.1", "resolved": "https://registry.npmjs.org/argparse/-/argparse-2.0.1.tgz", @@ -1930,6 +1992,12 @@ "integrity": "sha512-Kvp459HrV2FEJ1CAsi1Ku+MY3kasH19TFykTz2xWmMeq6bk2NU3XXvfJ+Q61m0xktWwt+1HSYf3JZsTms3aRJg==", "dev": true }, + "node_modules/create-require": { + "version": "1.1.1", + "resolved": "https://registry.npmjs.org/create-require/-/create-require-1.1.1.tgz", + "integrity": "sha512-dcKFX3jn0MpIaXjisoRvexIJVEKzaq7z2rZKxf+MSr9TkdmHmsU4m2lcLojrj/FHl8mk5VxMmYA+ftRkP/3oKQ==", + "dev": true + }, "node_modules/cross-spawn": { "version": "7.0.3", "resolved": "https://registry.npmjs.org/cross-spawn/-/cross-spawn-7.0.3.tgz", @@ -1991,6 +2059,15 @@ "node": ">=8" } }, + "node_modules/diff": { + "version": "4.0.2", + "resolved": "https://registry.npmjs.org/diff/-/diff-4.0.2.tgz", + "integrity": "sha512-58lmxKSA4BNyLz+HHMUzlOEpg09FV+ev6ZMe3vJihgdxzgcwZ8VoEEPmALCZG9LmqfVoNMMKpttIYTVG6uDY7A==", + "dev": true, + "engines": { + "node": ">=0.3.1" + } + }, "node_modules/diff-sequences": { "version": "29.3.1", "resolved": "https://registry.npmjs.org/diff-sequences/-/diff-sequences-29.3.1.tgz", @@ -4464,6 +4541,49 @@ } } }, + "node_modules/ts-node": { + "version": "10.9.1", + "resolved": "https://registry.npmjs.org/ts-node/-/ts-node-10.9.1.tgz", + "integrity": "sha512-NtVysVPkxxrwFGUUxGYhfux8k78pQB3JqYBXlLRZgdGUqTO5wU/UyHop5p70iEbGhB7q5KmiZiU0Y3KlJrScEw==", + "dev": true, + "dependencies": { + "@cspotcode/source-map-support": "^0.8.0", + "@tsconfig/node10": "^1.0.7", + "@tsconfig/node12": "^1.0.7", + "@tsconfig/node14": "^1.0.0", + "@tsconfig/node16": "^1.0.2", + "acorn": "^8.4.1", + "acorn-walk": "^8.1.1", + "arg": "^4.1.0", + "create-require": "^1.1.0", + "diff": "^4.0.1", + "make-error": "^1.1.1", + "v8-compile-cache-lib": "^3.0.1", + "yn": "3.1.1" + }, + "bin": { + "ts-node": "dist/bin.js", + "ts-node-cwd": "dist/bin-cwd.js", + "ts-node-esm": "dist/bin-esm.js", + "ts-node-script": "dist/bin-script.js", + "ts-node-transpile-only": "dist/bin-transpile.js", + "ts-script": "dist/bin-script-deprecated.js" + }, + "peerDependencies": { + "@swc/core": ">=1.2.50", + "@swc/wasm": ">=1.2.50", + "@types/node": "*", + "typescript": ">=2.7" + }, + "peerDependenciesMeta": { + "@swc/core": { + "optional": true + }, + "@swc/wasm": { + "optional": true + } + } + }, "node_modules/tslib": { "version": "1.14.1", "resolved": "https://registry.npmjs.org/tslib/-/tslib-1.14.1.tgz", @@ -4566,6 +4686,12 @@ "punycode": "^2.1.0" } }, + "node_modules/v8-compile-cache-lib": { + "version": "3.0.1", + "resolved": "https://registry.npmjs.org/v8-compile-cache-lib/-/v8-compile-cache-lib-3.0.1.tgz", + "integrity": "sha512-wa7YjyUGfNZngI/vtK0UHAN+lgDCxBPCylVXGp0zu59Fz5aiGtNXaq3DhIov063MorB+VfufLh3JlF2KdTK3xg==", + "dev": true + }, "node_modules/v8-to-istanbul": { "version": "9.0.1", "resolved": "https://registry.npmjs.org/v8-to-istanbul/-/v8-to-istanbul-9.0.1.tgz", @@ -4697,6 +4823,15 @@ "node": ">=12" } }, + "node_modules/yn": { + "version": "3.1.1", + "resolved": "https://registry.npmjs.org/yn/-/yn-3.1.1.tgz", + "integrity": "sha512-Ux4ygGWsu2c7isFWe8Yu1YluJmqVhxqK2cLXNQA5AcC3QfbGNpM7fu0Y8b/z16pXLnFxZYvWhd3fhBY9DLmC6Q==", + "dev": true, + "engines": { + "node": ">=6" + } + }, "node_modules/yocto-queue": { "version": "0.1.0", "resolved": "https://registry.npmjs.org/yocto-queue/-/yocto-queue-0.1.0.tgz", diff --git a/package.json b/package.json index 8c3ffe8..1686a31 100644 --- a/package.json +++ b/package.json @@ -42,6 +42,7 @@ "eslint": "^8.33.0", "jest": "^29.4.1", "ts-jest": "^29.0.5", + "ts-node": "^10.9.1", "typescript": "^4.9.5" } } diff --git a/src/index.ts b/src/index.ts index 28e806a..d3adc45 100644 --- a/src/index.ts +++ b/src/index.ts @@ -1,9 +1,9 @@ #!/usr/bin/env ts-node import { - promptContainsDenyListItems, - countPromptTokens, - encodePromptOutput, - promptContainsKnownAttack + containsDenyListItems, + countTokens, + encode, + containsKnownAttack } from './utils'; enum FAILURE_REASON { @@ -54,24 +54,25 @@ export class PromptGuard { // normalize -> quote -> escape -> check tokens -> check cache -> check for known attacks -> check allow list -> check deny list -> encode output // check the prompt token count - if (countPromptTokens(prompt) > this.promptGuardPolicy.maxTokens) + if (countTokens(prompt) > this.promptGuardPolicy.maxTokens) return { pass: false, output: FAILURE_REASON.MAX_TOKEN_THRESHOLD }; // check prompt against known prompt attacks if (!this.promptGuardPolicy.disableAttackMitigation) { - if (await promptContainsKnownAttack(prompt)) + if (await containsKnownAttack(prompt)) return { pass: false, output: FAILURE_REASON.KNOWN_ATTACK }; } // check prompt again the user defined deny list if ( - await promptContainsDenyListItems(prompt, this.promptGuardPolicy.denyList) + await containsDenyListItems(prompt, this.promptGuardPolicy.denyList) ) return { pass: false, output: FAILURE_REASON.DENY_LIST }; // encode the prompt output if encodeOutput is set by the user - if (this.promptGuardPolicy.encodeOutput) - prompt = encodePromptOutput(prompt); + if (this.promptGuardPolicy.encodeOutput) { + return { pass: true, output: encode(prompt) }; + } return { pass: true, output: prompt }; } diff --git a/src/utils/encoder/index.js b/src/utils/encoder/index.ts similarity index 86% rename from src/utils/encoder/index.js rename to src/utils/encoder/index.ts index af87b99..55f4e94 100644 --- a/src/utils/encoder/index.js +++ b/src/utils/encoder/index.ts @@ -1,43 +1,43 @@ // This file includes code which was modified from https://github.com/openai/gpt-2 // This file inclused code which was modified from https://github.com/NickHeiner/GPT-3-Encoder -// import { path } from "path"; -// import { fs } from "fs"; -// import { util } from "util"; +import * as path from "path"; +import * as fs from "fs"; +import * as util from "util"; -const path = require('path'); -const fs = require('fs'); -const util = require('util'); +// const path = require('path'); +// const fs = require('fs'); +// const util = require('util'); const encoder = JSON.parse( fs.readFileSync(path.join(__dirname, "./encoder.json")) ); const bpe_file = fs.readFileSync(path.join(__dirname, "./vocab.bpe"), "utf-8"); -const range = (x, y) => { +const range = (x: string, y: string) => { const res = Array.from(Array(y).keys()).slice(x); return res; }; -const ord = (x) => { +const ord = (x:string): number => { return x.charCodeAt(0); }; -const chr = (x) => { +const chr = (x: number): string => { return String.fromCharCode(x); }; const textEncoder = new util.TextEncoder("utf-8"); -const encodeStr = (str) => { +const encodeStr = (str:string) => { return Array.from(textEncoder.encode(str)).map((x) => x.toString()); }; const textDecoder = new util.TextDecoder("utf-8"); -const decodeStr = (arr) => { +const decodeStr = (arr: number[]) => { return textDecoder.decode(new Uint8Array(arr)); }; -const dictZip = (x, y) => { +const dictZip = (x: string[], y: string[]) => { const result = {}; x.map((_, i) => { result[x[i]] = y[i]; @@ -70,7 +70,7 @@ function bytes_to_unicode() { return result; } -function get_pairs(word) { +function get_pairs(word: string[]) { const pairs = new Set(); let prev_char = word[0]; for (let i = 1; i < word.length; i++) { @@ -107,13 +107,13 @@ Object.keys(byte_encoder).map((x) => { const bpe_ranks = dictZip(bpe_merges, range(0, bpe_merges.length)); const cache = new Map(); -function bpe(token) { +function bpe(token:string) { if (cache.has(token)) { return cache.get(token); } ``; - let word = token.split(""); + let word: string[] = token.split(""); let pairs = get_pairs(word); @@ -143,7 +143,7 @@ function bpe(token) { const first = bigram[0]; const second = bigram[1]; - let new_word = []; + let new_word: string[] = []; let i = 0; while (i < word.length) { @@ -182,7 +182,7 @@ function bpe(token) { // encoding each match using the encodeStr function and the byte_encoder mapping, // and then applying the bpe function to the encoded token. The number of tokens produced by the bpe function is then added to the count variable. // Finally, the count variable is returned as the result. -function countTokens(text) { +export function countTokens(text: string) { let count = 0; const matches = Array.from(text.matchAll(pat)).map((x) => x[0]); for (let token of matches) { @@ -197,7 +197,7 @@ function countTokens(text) { return count; } -function encode(text) { +export function encode(text) { let bpe_tokens = []; const matches = Array.from(text.matchAll(pat)).map((x) => x[0]); for (let token of matches) { @@ -215,14 +215,14 @@ function encode(text) { return bpe_tokens; } -function decode(tokens) { +export function decode(tokens) { let text = tokens.map((x) => decoder[x]).join(""); text = decodeStr(text.split("").map((x) => byte_decoder[x])); return text; } -module.exports = { - encode, - decode, - countTokens, -}; +// module.exports = { +// encode, +// decode, +// countTokens, +// }; diff --git a/src/utils/index.ts b/src/utils/index.ts index 374678d..03e564c 100644 --- a/src/utils/index.ts +++ b/src/utils/index.ts @@ -1,9 +1,10 @@ -import { containsDenyListItems } from "./denylist"; -export const promptContainsDenyListItems = containsDenyListItems; +export { containsDenyListItems } from "./denylist"; +// export const promptContainsDenyListItems = containsDenyListItems; -import { containsKnownAttack } from "./attackmitigation"; -export const promptContainsKnownAttack = containsKnownAttack; +export { containsKnownAttack } from "./attackmitigation"; +// export const promptContainsKnownAttack = containsKnownAttack; -const encoder = require("./encoder"); -export const countPromptTokens = encoder.countTokens; -export const encodePromptOutput = encoder.encode; +// const encoder = require("./encoder"); +export { countTokens, encode } from './encoder' +// export const countPromptTokens = countTokens; +// export const encodePromptOutput = encode; diff --git a/tsconfig.json b/tsconfig.json index 2fd6387..55e7f86 100644 --- a/tsconfig.json +++ b/tsconfig.json @@ -29,7 +29,7 @@ // "paths": {}, /* Specify a set of entries that re-map imports to additional lookup locations. */ // "rootDirs": [], /* Allow multiple folders to be treated as one when resolving modules. */ // "typeRoots": [], /* Specify multiple folders that act like './node_modules/@types'. */ - // "types": [], /* Specify type package names to be included without being referenced in a source file. */ + "types": ["node"], /* Specify type package names to be included without being referenced in a source file. */ // "allowUmdGlobalAccess": true, /* Allow accessing UMD globals from modules. */ // "moduleSuffixes": [], /* List of file name suffixes to search when resolving a module. */ // "resolveJsonModule": true, /* Enable importing .json files. */ From ba56eaf07d116c6568ace3361c1f22dcc19df293 Mon Sep 17 00:00:00 2001 From: Phil Sautter <20444474+redeux@users.noreply.github.com> Date: Thu, 9 Feb 2023 10:12:15 -0500 Subject: [PATCH 2/5] Refactor exports --- src/utils/index.ts | 9 +-------- 1 file changed, 1 insertion(+), 8 deletions(-) diff --git a/src/utils/index.ts b/src/utils/index.ts index 03e564c..8ed7136 100644 --- a/src/utils/index.ts +++ b/src/utils/index.ts @@ -1,10 +1,3 @@ export { containsDenyListItems } from "./denylist"; -// export const promptContainsDenyListItems = containsDenyListItems; - export { containsKnownAttack } from "./attackmitigation"; -// export const promptContainsKnownAttack = containsKnownAttack; - -// const encoder = require("./encoder"); -export { countTokens, encode } from './encoder' -// export const countPromptTokens = countTokens; -// export const encodePromptOutput = encode; +export { countTokens, encode } from './encoder' \ No newline at end of file From fc5889c9a85e4a0dd7469dd09baaf18a935a7314 Mon Sep 17 00:00:00 2001 From: Phil Sautter <20444474+redeux@users.noreply.github.com> Date: Sun, 30 Apr 2023 09:28:48 -0400 Subject: [PATCH 3/5] Convert encoder to TypeScript --- src/utils/encoder/index.ts | 172 +++++++++++++++++-------------------- 1 file changed, 78 insertions(+), 94 deletions(-) diff --git a/src/utils/encoder/index.ts b/src/utils/encoder/index.ts index 55f4e94..7629255 100644 --- a/src/utils/encoder/index.ts +++ b/src/utils/encoder/index.ts @@ -1,25 +1,23 @@ // This file includes code which was modified from https://github.com/openai/gpt-2 // This file inclused code which was modified from https://github.com/NickHeiner/GPT-3-Encoder -import * as path from "path"; -import * as fs from "fs"; -import * as util from "util"; +import * as fs from 'fs'; +import * as path from 'path'; -// const path = require('path'); -// const fs = require('fs'); -// const util = require('util'); - -const encoder = JSON.parse( - fs.readFileSync(path.join(__dirname, "./encoder.json")) +const encoder: { [key: string]: number } = JSON.parse( + fs.readFileSync(path.join(__dirname, './encoder.json'), 'utf-8') +); +const bpe_file: string = fs.readFileSync( + path.join(__dirname, './vocab.bpe'), + 'utf-8' ); -const bpe_file = fs.readFileSync(path.join(__dirname, "./vocab.bpe"), "utf-8"); -const range = (x: string, y: string) => { +const range = (x: number, y: number): number[] => { const res = Array.from(Array(y).keys()).slice(x); return res; }; -const ord = (x:string): number => { +const ord = (x: string): number => { return x.charCodeAt(0); }; @@ -27,31 +25,26 @@ const chr = (x: number): string => { return String.fromCharCode(x); }; -const textEncoder = new util.TextEncoder("utf-8"); -const encodeStr = (str:string) => { - return Array.from(textEncoder.encode(str)).map((x) => x.toString()); +const textEncoder = new TextEncoder(); +const encodeStr = (str: string): string[] => { + return Array.from(textEncoder.encode(str)).map(x => x.toString()); }; -const textDecoder = new util.TextDecoder("utf-8"); -const decodeStr = (arr: number[]) => { - return textDecoder.decode(new Uint8Array(arr)); +const textDecoder = new TextDecoder(); +const decodeStr = (arr: string[]): string => { + return textDecoder.decode(new Uint8Array(arr.map(Number))); }; -const dictZip = (x: string[], y: string[]) => { - const result = {}; - x.map((_, i) => { - result[x[i]] = y[i]; - }); +function dictZip(x: string[][], y: number[]): { [key: string]: number } { + const result: { [key: string]: number } = {}; + x.map((keyArr, i) => { result[keyArr.join('')] = y[i] }); return result; -}; +} -function bytes_to_unicode() { - const bs = range(ord("!"), ord("~") + 1).concat( - range(ord("¡"), ord("¬") + 1), - range(ord("®"), ord("ÿ") + 1) - ); +function bytes_to_unicode(): { [key: number]: string } { + const bs: number[] = range(ord('!'), ord('~') + 1).concat(range(ord('¡'), ord('¬') + 1), range(ord('®'), ord('ÿ') + 1)); - let cs = bs.slice(); + const cs: number[] = bs.slice(); let n = 0; for (let b = 0; b < 2 ** 8; b++) { if (!bs.includes(b)) { @@ -61,17 +54,16 @@ function bytes_to_unicode() { } } - cs = cs.map((x) => chr(x)); + const csChars: string[] = cs.map(x => chr(x)); - const result = {}; - bs.map((_, i) => { - result[bs[i]] = cs[i]; - }); + const result: { [key: number]: string } = {}; + bs.map((_, i) => { result[bs[i]] = csChars[i] }); return result; } -function get_pairs(word: string[]) { - const pairs = new Set(); + +function get_pairs(word: string[]): Set { + const pairs = new Set(); let prev_char = word[0]; for (let i = 1; i < word.length; i++) { const char = word[i]; @@ -84,36 +76,34 @@ function get_pairs(word: string[]) { const pat = /'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+/gu; -const decoder = {}; -Object.keys(encoder).map((x) => { +const decoder: { [key: number]: string } = {}; +Object.keys(encoder).map(x => { decoder[encoder[x]] = x; }); -const lines = bpe_file.split("\n"); +const lines = bpe_file.split('\n'); -// bpe_merges = [tuple(merge_str.split()) for merge_str in bpe_data.split("\n")[1:-1]] -const bpe_merges = lines.slice(1, lines.length - 1).map((x) => { +const bpe_merges: string[][] = lines.slice(1, lines.length - 1).map(x => { return x.split(/(\s+)/).filter(function (e) { return e.trim().length > 0; }); }); const byte_encoder = bytes_to_unicode(); -const byte_decoder = {}; -Object.keys(byte_encoder).map((x) => { - byte_decoder[byte_encoder[x]] = x; +const byte_decoder: { [key: string]: string } = {}; +Object.keys(byte_encoder).map(x => { + byte_decoder[byte_encoder[Number(x)]] = x; }); const bpe_ranks = dictZip(bpe_merges, range(0, bpe_merges.length)); -const cache = new Map(); +const cache: Map = new Map(); -function bpe(token:string) { +function bpe(token: string): string { if (cache.has(token)) { - return cache.get(token); + return cache.get(token) as string; } - ``; - let word: string[] = token.split(""); + let word = token.split(''); let pairs = get_pairs(word); @@ -121,24 +111,26 @@ function bpe(token:string) { return token; } - while (true) { - const minPairs = {}; - Array.from(pairs).map((pair) => { - const rank = bpe_ranks[pair]; + let shouldContinue = true; + + while (shouldContinue) { + const minPairs: { [key: number]: string[] } = {}; + Array.from(pairs).map(pair => { + const rank = bpe_ranks[pair.join('')]; minPairs[isNaN(rank) ? 10e10 : rank] = pair; }); const bigram = minPairs[ Math.min( - ...Object.keys(minPairs).map((x) => { + ...Object.keys(minPairs).map(x => { return parseInt(x); }) ) ]; - if (!(bigram in bpe_ranks)) { - break; + if (!(bigram.join('') in bpe_ranks)) { + shouldContinue = false; } const first = bigram[0]; @@ -150,7 +142,7 @@ function bpe(token:string) { const j = word.indexOf(first, i); if (j === -1) { new_word = new_word.concat(word.slice(i)); - break; + shouldContinue = false; } new_word = new_word.concat(word.slice(i, j)); i = j; @@ -166,63 +158,55 @@ function bpe(token:string) { word = new_word; if (word.length === 1) { - break; + shouldContinue = false; } else { pairs = get_pairs(word); } } - word = word.join(" "); - cache.set(token, word); - - return word; + const wordStr = word.join(''); + cache.set(token, wordStr); + + return wordStr; } -// This function works by iterating through the matches of the pat pattern in the input text, -// encoding each match using the encodeStr function and the byte_encoder mapping, -// and then applying the bpe function to the encoded token. The number of tokens produced by the bpe function is then added to the count variable. -// Finally, the count variable is returned as the result. -export function countTokens(text: string) { - let count = 0; - const matches = Array.from(text.matchAll(pat)).map((x) => x[0]); +function encode(text: string): number[] { + let bpe_tokens: number[] = []; + const matches = Array.from(text.matchAll(pat)).map(x => x[0]); for (let token of matches) { token = encodeStr(token) - .map((x) => { - return byte_encoder[x]; + .map(x => { + return byte_encoder[Number(x)]; }) - .join(""); + .join(''); - count += bpe(token).split(" ").length; + const new_tokens = bpe(token) + .split(' ') + .map(x => encoder[x]); + bpe_tokens = bpe_tokens.concat(new_tokens); } - return count; + return bpe_tokens; } -export function encode(text) { - let bpe_tokens = []; - const matches = Array.from(text.matchAll(pat)).map((x) => x[0]); +function countTokens(text: string): number { + let count = 0; + const matches = Array.from(text.matchAll(pat)).map(x => x[0]); for (let token of matches) { token = encodeStr(token) - .map((x) => { - return byte_encoder[x]; + .map(x => { + return byte_encoder[Number(x)]; }) - .join(""); + .join(''); - const new_tokens = bpe(token) - .split(" ") - .map((x) => encoder[x]); - bpe_tokens = bpe_tokens.concat(new_tokens); + count += bpe(token).split(' ').length; } - return bpe_tokens; + return count; } -export function decode(tokens) { - let text = tokens.map((x) => decoder[x]).join(""); - text = decodeStr(text.split("").map((x) => byte_decoder[x])); +function decode(tokens: number[]): string { + let text = tokens.map(x => decoder[x]).join(''); + text = decodeStr(text.split('').map(x => byte_decoder[x])); return text; } -// module.exports = { -// encode, -// decode, -// countTokens, -// }; +export { encode, decode, countTokens }; \ No newline at end of file From 04b94879ff8100e771df0d979739b889778b3b46 Mon Sep 17 00:00:00 2001 From: Phil Sautter <20444474+redeux@users.noreply.github.com> Date: Sun, 30 Apr 2023 11:15:04 -0400 Subject: [PATCH 4/5] Fix CodeQL failures --- src/utils/encoder/index.ts | 38 ++++++++++++++++++++++++++++++-------- 1 file changed, 30 insertions(+), 8 deletions(-) diff --git a/src/utils/encoder/index.ts b/src/utils/encoder/index.ts index 7629255..b30bb41 100644 --- a/src/utils/encoder/index.ts +++ b/src/utils/encoder/index.ts @@ -37,12 +37,17 @@ const decodeStr = (arr: string[]): string => { function dictZip(x: string[][], y: number[]): { [key: string]: number } { const result: { [key: string]: number } = {}; - x.map((keyArr, i) => { result[keyArr.join('')] = y[i] }); + x.map((keyArr, i) => { + result[keyArr.join('')] = y[i]; + }); return result; } function bytes_to_unicode(): { [key: number]: string } { - const bs: number[] = range(ord('!'), ord('~') + 1).concat(range(ord('¡'), ord('¬') + 1), range(ord('®'), ord('ÿ') + 1)); + const bs: number[] = range(ord('!'), ord('~') + 1).concat( + range(ord('¡'), ord('¬') + 1), + range(ord('®'), ord('ÿ') + 1) + ); const cs: number[] = bs.slice(); let n = 0; @@ -57,11 +62,12 @@ function bytes_to_unicode(): { [key: number]: string } { const csChars: string[] = cs.map(x => chr(x)); const result: { [key: number]: string } = {}; - bs.map((_, i) => { result[bs[i]] = csChars[i] }); + bs.map((_, i) => { + result[bs[i]] = csChars[i]; + }); return result; } - function get_pairs(word: string[]): Set { const pairs = new Set(); let prev_char = word[0]; @@ -166,13 +172,20 @@ function bpe(token: string): string { const wordStr = word.join(''); cache.set(token, wordStr); - + return wordStr; } function encode(text: string): number[] { + let match: RegExpExecArray | null; + const matches: string[] = []; let bpe_tokens: number[] = []; - const matches = Array.from(text.matchAll(pat)).map(x => x[0]); + const regex = new RegExp(pat); + + while ((match = regex.exec(text)) !== null) { + matches.push(match[0]); + } + for (let token of matches) { token = encodeStr(token) .map(x => { @@ -189,8 +202,17 @@ function encode(text: string): number[] { } function countTokens(text: string): number { + let match: RegExpExecArray | null; + const matches: string[] = []; + const regex = new RegExp(pat); let count = 0; - const matches = Array.from(text.matchAll(pat)).map(x => x[0]); + + while ((match = regex.exec(text)) !== null) { + matches.push(match[0]); + } + + + for (let token of matches) { token = encodeStr(token) .map(x => { @@ -209,4 +231,4 @@ function decode(tokens: number[]): string { return text; } -export { encode, decode, countTokens }; \ No newline at end of file +export { encode, decode, countTokens }; From cb855b856c91c60c89ed26f53964ab05865a5509 Mon Sep 17 00:00:00 2001 From: Phil Sautter <20444474+redeux@users.noreply.github.com> Date: Sun, 30 Apr 2023 11:22:52 -0400 Subject: [PATCH 5/5] Add exported function types --- src/utils/encoder/index.d.ts | 3 +++ 1 file changed, 3 insertions(+) create mode 100644 src/utils/encoder/index.d.ts diff --git a/src/utils/encoder/index.d.ts b/src/utils/encoder/index.d.ts new file mode 100644 index 0000000..c41f613 --- /dev/null +++ b/src/utils/encoder/index.d.ts @@ -0,0 +1,3 @@ +export function encode(text: string): number[]; +export function decode(tokens: number[]): string; +export function countTokens(text: string): number; \ No newline at end of file