From cb7a7d624baa168c3614bf5c969a8e829e3a6b49 Mon Sep 17 00:00:00 2001 From: Brian Botha Date: Tue, 3 Jan 2023 16:23:12 +1100 Subject: [PATCH 01/44] feat: message separation and parsing --- package-lock.json | 13 ++ package.json | 3 +- src/rpc/errors.ts | 10 + src/rpc/index.ts | 0 src/rpc/types.ts | 109 +++++++++++ src/rpc/utils.ts | 316 ++++++++++++++++++++++++++++++ test-ajv.ts | 37 ++++ test-deepkit-rpc-client.ts | 45 +++++ test-deepkit-rpc-server.ts | 44 +++++ test-dgram.ts | 43 ++++ test-g.ts | 22 +++ test-generator-exception.ts | 90 +++++++++ test-generators.ts | 377 ++++++++++++++++++++++++++++++++++++ test-gg.ts | 211 ++++++++++++++++++++ test-hashing.ts | 37 ++++ test-muxrpc-client.ts | 176 +++++++++++++++++ test-muxrpc-server.ts | 201 +++++++++++++++++++ test-subject.ts | 20 ++ tests/rpc/utils.test.ts | 108 +++++++++++ tests/rpc/utils.ts | 79 ++++++++ 20 files changed, 1940 insertions(+), 1 deletion(-) create mode 100644 src/rpc/errors.ts create mode 100644 src/rpc/index.ts create mode 100644 src/rpc/types.ts create mode 100644 src/rpc/utils.ts create mode 100644 test-ajv.ts create mode 100644 test-deepkit-rpc-client.ts create mode 100644 test-deepkit-rpc-server.ts create mode 100644 test-dgram.ts create mode 100644 test-g.ts create mode 100644 test-generator-exception.ts create mode 100644 test-generators.ts create mode 100644 test-gg.ts create mode 100644 test-hashing.ts create mode 100644 test-muxrpc-client.ts create mode 100644 test-muxrpc-server.ts create mode 100644 test-subject.ts create mode 100644 tests/rpc/utils.test.ts create mode 100644 tests/rpc/utils.ts diff --git a/package-lock.json b/package-lock.json index 64b78c715..eff34456b 100644 --- a/package-lock.json +++ b/package-lock.json @@ -59,6 +59,7 @@ "devDependencies": { "@babel/preset-env": "^7.13.10", "@fast-check/jest": "^1.1.0", + "@streamparser/json": "^0.0.12", "@swc/core": "^1.2.215", "@types/cross-spawn": "^6.0.2", "@types/google-protobuf": "^3.7.4", @@ -3049,6 +3050,12 @@ "@sinonjs/commons": "^1.7.0" } }, + "node_modules/@streamparser/json": { + "version": "0.0.12", + "resolved": "https://registry.npmjs.org/@streamparser/json/-/json-0.0.12.tgz", + "integrity": "sha512-+kmRpd+EeTFd3qNt1AoKphJqbAN26ZDsbiwqjBFeoAmdCyiUO19xMXPtYi9vovAj9a7OAJnvWtiHkwwjU2Fx4Q==", + "dev": true + }, "node_modules/@swc/core": { "version": "1.2.218", "resolved": "https://registry.npmjs.org/@swc/core/-/core-1.2.218.tgz", @@ -14314,6 +14321,12 @@ "@sinonjs/commons": "^1.7.0" } }, + "@streamparser/json": { + "version": "0.0.12", + "resolved": "https://registry.npmjs.org/@streamparser/json/-/json-0.0.12.tgz", + "integrity": "sha512-+kmRpd+EeTFd3qNt1AoKphJqbAN26ZDsbiwqjBFeoAmdCyiUO19xMXPtYi9vovAj9a7OAJnvWtiHkwwjU2Fx4Q==", + "dev": true + }, "@swc/core": { "version": "1.2.218", "resolved": "https://registry.npmjs.org/@swc/core/-/core-1.2.218.tgz", diff --git a/package.json b/package.json index ac86c2522..7631c48c2 100644 --- a/package.json +++ b/package.json @@ -163,6 +163,7 @@ "ts-node": "^10.9.1", "tsconfig-paths": "^3.9.0", "typedoc": "^0.23.21", - "typescript": "^4.9.3" + "typescript": "^4.9.3", + "@streamparser/json": "^0.0.12" } } diff --git a/src/rpc/errors.ts b/src/rpc/errors.ts new file mode 100644 index 000000000..31bd028ec --- /dev/null +++ b/src/rpc/errors.ts @@ -0,0 +1,10 @@ +import { ErrorPolykey, sysexits } from '../errors'; + +class ErrorRpc extends ErrorPolykey {} + +class ErrorRpcParse extends ErrorRpc { + static description = 'Failed to parse Buffer stream'; + exitCode = sysexits.SOFTWARE; +} + +export { ErrorRpc, ErrorRpcParse }; diff --git a/src/rpc/index.ts b/src/rpc/index.ts new file mode 100644 index 000000000..e69de29bb diff --git a/src/rpc/types.ts b/src/rpc/types.ts new file mode 100644 index 000000000..b382bc8d5 --- /dev/null +++ b/src/rpc/types.ts @@ -0,0 +1,109 @@ +import type { POJO } from '../types'; + +/** + * This is the JSON RPC request object. this is the generic message type used for the RPC. + */ +type JsonRpcRequest = { + type: 'JsonRpcRequest'; + // A String specifying the version of the JSON-RPC protocol. MUST be exactly "2.0" + jsonrpc: '2.0'; + // A String containing the name of the method to be invoked. Method names that begin with the word rpc followed by a + // period character (U+002E or ASCII 46) are reserved for rpc-internal methods and extensions and MUST NOT be used + // for anything else. + method: string; + // A Structured value that holds the parameter values to be used during the invocation of the method. + // This member MAY be omitted. + params?: T; + // An identifier established by the Client that MUST contain a String, Number, or NULL value if included. + // If it is not included it is assumed to be a notification. The value SHOULD normally not be Null [1] and Numbers + // SHOULD NOT contain fractional parts [2] + id: string | number | null; +}; + +type JsonRpcNotification = { + type: 'JsonRpcNotification'; + // A String specifying the version of the JSON-RPC protocol. MUST be exactly "2.0" + jsonrpc: '2.0'; + // A String containing the name of the method to be invoked. Method names that begin with the word rpc followed by a + // period character (U+002E or ASCII 46) are reserved for rpc-internal methods and extensions and MUST NOT be used + // for anything else. + method: string; + // A Structured value that holds the parameter values to be used during the invocation of the method. + // This member MAY be omitted. + params?: T; +}; + +type JsonRpcResponseResult = { + type: 'JsonRpcResponseResult'; + // A String specifying the version of the JSON-RPC protocol. MUST be exactly "2.0". + jsonrpc: '2.0'; + // This member is REQUIRED on success. + // This member MUST NOT exist if there was an error invoking the method. + // The value of this member is determined by the method invoked on the Server. + result: T; + // This member is REQUIRED. + // It MUST be the same as the value of the id member in the Request Object. + // If there was an error in detecting the id in the Request object (e.g. Parse error/Invalid Request), + // it MUST be Null. + id: string | number | null; +}; + +type JsonRpcResponseError = { + type: 'JsonRpcResponseError'; + // A String specifying the version of the JSON-RPC protocol. MUST be exactly "2.0". + jsonrpc: '2.0'; + // This member is REQUIRED on error. + // This member MUST NOT exist if there was no error triggered during invocation. + // The value for this member MUST be an Object as defined in section 5.1. + error: JsonRpcError; + // This member is REQUIRED. + // It MUST be the same as the value of the id member in the Request Object. + // If there was an error in detecting the id in the Request object (e.g. Parse error/Invalid Request), + // it MUST be Null. + id: string | number | null; +}; + +// The error codes from and including -32768 to -32000 are reserved for pre-defined errors. Any code within this range, +// but not defined explicitly below is reserved for future use. The error codes are nearly the same as those suggested +// for XML-RPC at the following url: http://xmlrpc-epi.sourceforge.net/specs/rfc.fault_codes.php +// +// code message meaning +// -32700 Parse error Invalid JSON was received by the server. An error occurred on the server while parsing the JSON text. +// -32600 Invalid Request The JSON sent is not a valid Request object. +// -32601 Method not found The method does not exist / is not available. +// -32602 Invalid params Invalid method parameter(s). +// -32603 Internal error Internal JSON-RPC error. +// -32000 to -32099 + +type JsonRpcError = { + // A Number that indicates the error type that occurred. + // This MUST be an integer. + code: number; + // A String providing a short description of the error. + // The message SHOULD be limited to a concise single sentence. + message: string; + // A Primitive or Structured value that contains additional information about the error. + // This may be omitted. + // The value of this member is defined by the Server (e.g. detailed error information, nested errors etc.). + data?: T; +}; + +type JsonRpcResponse = + | JsonRpcResponseResult + | JsonRpcResponseError; + +type jsonRpcMessage = + | JsonRpcRequest + | JsonRpcNotification + | JsonRpcResponseResult + | JsonRpcResponseError; + +export type { + JsonRpcRequest, + JsonRpcNotification, + JsonRpcResponseResult, + JsonRpcResponseError, + JsonRpcError, + JsonRpcResponse, + jsonRpcMessage, +}; diff --git a/src/rpc/utils.ts b/src/rpc/utils.ts new file mode 100644 index 000000000..82ceab7de --- /dev/null +++ b/src/rpc/utils.ts @@ -0,0 +1,316 @@ +import type { Transformer, TransformerTransformCallback } from 'stream/web'; +import type { + JsonRpcError, + jsonRpcMessage, + JsonRpcNotification, + JsonRpcRequest, + JsonRpcResponseError, + JsonRpcResponseResult, +} from 'rpc/types'; +import type { POJO } from '../types'; +import { TransformStream } from 'stream/web'; +import * as utils from 'utils/index'; +import * as validationErrors from 'validation/errors'; +import * as rpcErrors from './errors'; +import { promise } from '../utils'; +const jsonStreamParsers = require('@streamparser/json'); + +class JsonToJsonMessage implements Transformer { + protected buffer = Buffer.alloc(0); + + /** + * This function finds the index of the closing `}` bracket of the top level + * JSON object. It makes use of a JSON parser tokenizer to find the `{}` + * tokens that are not within strings and counts them to find the top level + * matching `{}` pair. + */ + protected async findCompleteMessageIndex(input: Buffer): Promise { + const tokenizer = new jsonStreamParsers.Tokenizer(); + let braceCount = 0; + let escapes = 0; + const foundOffset = promise(); + tokenizer.onToken = (tokenData) => { + if (tokenData.token === jsonStreamParsers.TokenType.LEFT_BRACE) { + braceCount += 1; + } else if (tokenData.token === jsonStreamParsers.TokenType.RIGHT_BRACE) { + braceCount += -1; + if (braceCount === 0) foundOffset.resolveP(tokenData.offset + escapes); + } else if (tokenData.token === jsonStreamParsers.TokenType.STRING) { + const string = tokenData.value as string; + // `JSON.stringify` changes the length of a string when special + // characters are present. This makes the offset we find wrong when + // getting the substring. We need to compensate for this by getting the + // difference in string length. + escapes += JSON.stringify([string]).length - string.length - 4; + } + }; + tokenizer.onEnd = () => foundOffset.resolveP(-1); + try { + tokenizer.write(input); + } catch (e) { + throw new rpcErrors.ErrorRpcParse('TMP StreamParseError', { cause: e }); + } + try { + tokenizer.end(); + } catch { + foundOffset.resolveP(-1); + } + return await foundOffset.p; + } + + transform: TransformerTransformCallback = async ( + chunk, + controller, + ) => { + this.buffer = Buffer.concat([this.buffer, chunk]); + while (this.buffer.length > 0) { + const index = await this.findCompleteMessageIndex(this.buffer); + if (index <= 0) break; + const outputBuffer = this.buffer.subarray(0, index + 1); + try { + controller.enqueue(JSON.parse(outputBuffer.toString('utf-8'))); + } catch (e) { + throw new rpcErrors.ErrorRpcParse(undefined, { cause: e }); + } + this.buffer = this.buffer.subarray(index + 1); + } + }; +} + +// TODO: rename to something more descriptive? +class JsonToJsonMessageStream extends TransformStream { + constructor() { + super(new JsonToJsonMessage()); + } +} + +const messagetypes = [ + 'JsonRpcRequest', + 'JsonRpcNotification', + 'JsonRpcResponseResult', + 'JsonRpcResponseError', +]; + +function parseJsonRpcRequest( + message: unknown, +): JsonRpcRequest { + if (!utils.isObject(message)) { + throw new validationErrors.ErrorParse('must be a JSON POJO'); + } + if (!('type' in message)) { + throw new validationErrors.ErrorParse('`type` property must be defined'); + } + if (typeof message.type !== 'string') { + throw new validationErrors.ErrorParse('`type` property must be a string'); + } + if (message.type !== 'JsonRpcRequest') { + throw new validationErrors.ErrorParse( + '`type` property must be "JsonRpcRequest"', + ); + } + if (!('method' in message)) { + throw new validationErrors.ErrorParse('`method` property must be defined'); + } + if (typeof message.method !== 'string') { + throw new validationErrors.ErrorParse('`method` property must be a string'); + } + if ('params' in message && !utils.isObject(message.params)) { + throw new validationErrors.ErrorParse('`params` property must be a POJO'); + } + if (!('id' in message)) { + throw new validationErrors.ErrorParse('`id` property must be defined'); + } + if ( + typeof message.id !== 'string' && + typeof message.id !== 'number' && + typeof message.id !== null + ) { + throw new validationErrors.ErrorParse( + '`id` property must be a string, number or null', + ); + } + return message as JsonRpcRequest; +} + +function parseJsonRpcNotification( + message: unknown, +): JsonRpcNotification { + if (!utils.isObject(message)) { + throw new validationErrors.ErrorParse('must be a JSON POJO'); + } + if (!('type' in message)) { + throw new validationErrors.ErrorParse('`type` property must be defined'); + } + if (typeof message.type !== 'string') { + throw new validationErrors.ErrorParse('`type` property must be a string'); + } + if (message.type !== 'JsonRpcNotification') { + throw new validationErrors.ErrorParse( + '`type` property must be "JsonRpcRequest"', + ); + } + if (!('method' in message)) { + throw new validationErrors.ErrorParse('`method` property must be defined'); + } + if (typeof message.method !== 'string') { + throw new validationErrors.ErrorParse('`method` property must be a string'); + } + if ('params' in message && !utils.isObject(message.params)) { + throw new validationErrors.ErrorParse('`params` property must be a POJO'); + } + if ('id' in message) { + throw new validationErrors.ErrorParse('`id` property must not be defined'); + } + return message as JsonRpcNotification; +} + +function parseJsonRpcResponseResult( + message: unknown, +): JsonRpcResponseResult { + if (!utils.isObject(message)) { + throw new validationErrors.ErrorParse('must be a JSON POJO'); + } + if (!('type' in message)) { + throw new validationErrors.ErrorParse('`type` property must be defined'); + } + if (typeof message.type !== 'string') { + throw new validationErrors.ErrorParse('`type` property must be a string'); + } + if (message.type !== 'JsonRpcResponseResult') { + throw new validationErrors.ErrorParse( + '`type` property must be "JsonRpcRequest"', + ); + } + if (!('result' in message)) { + throw new validationErrors.ErrorParse('`result` property must be defined'); + } + if ('error' in message) { + throw new validationErrors.ErrorParse( + '`error` property must not be defined', + ); + } + if (!utils.isObject(message.result)) { + throw new validationErrors.ErrorParse('`result` property must be a POJO'); + } + if (!('id' in message)) { + throw new validationErrors.ErrorParse('`id` property must be defined'); + } + if ( + typeof message.id !== 'string' && + typeof message.id !== 'number' && + typeof message.id !== null + ) { + throw new validationErrors.ErrorParse( + '`id` property must be a string, number or null', + ); + } + return message as JsonRpcResponseResult; +} + +function parseJsonRpcResponseError( + message: unknown, +): JsonRpcResponseError { + if (!utils.isObject(message)) { + throw new validationErrors.ErrorParse('must be a JSON POJO'); + } + if (!('type' in message)) { + throw new validationErrors.ErrorParse('`type` property must be defined'); + } + if (typeof message.type !== 'string') { + throw new validationErrors.ErrorParse('`type` property must be a string'); + } + if (message.type !== 'JsonRpcResponseError') { + throw new validationErrors.ErrorParse( + '`type` property must be "JsonRpcResponseError"', + ); + } + if ('result' in message) { + throw new validationErrors.ErrorParse( + '`result` property must not be defined', + ); + } + if (!('error' in message)) { + throw new validationErrors.ErrorParse('`error` property must be defined'); + } + parseJsonRpcError(message.error); + if (!('id' in message)) { + throw new validationErrors.ErrorParse('`id` property must be defined'); + } + if ( + typeof message.id !== 'string' && + typeof message.id !== 'number' && + typeof message.id !== null + ) { + throw new validationErrors.ErrorParse( + '`id` property must be a string, number or null', + ); + } + return message as JsonRpcResponseError; +} + +function parseJsonRpcError(message: unknown): JsonRpcError { + if (!utils.isObject(message)) { + throw new validationErrors.ErrorParse('must be a JSON POJO'); + } + if (!('code' in message)) { + throw new validationErrors.ErrorParse('`code` property must be defined'); + } + if (typeof message.code !== 'number') { + throw new validationErrors.ErrorParse('`code` property must be a number'); + } + if (!('message' in message)) { + throw new validationErrors.ErrorParse('`message` property must be defined'); + } + if (typeof message.message !== 'string') { + throw new validationErrors.ErrorParse( + '`message` property must be a string', + ); + } + if ('data' in message && !utils.isObject(message.data)) { + throw new validationErrors.ErrorParse('`data` property must be a POJO'); + } + return message as JsonRpcError; +} + +function parseJsonRpcMessage( + message: unknown, +): jsonRpcMessage { + if (!utils.isObject(message)) { + throw new validationErrors.ErrorParse('must be a JSON POJO'); + } + if (!('type' in message)) { + throw new validationErrors.ErrorParse('`type` property must be defined'); + } + if (typeof message.type !== 'string') { + throw new validationErrors.ErrorParse('`type` property must be a string'); + } + if (!(message.type in messagetypes)) { + throw new validationErrors.ErrorParse( + '`type` property must be a valid type', + ); + } + if (!('jsonrpc' in message)) { + throw new validationErrors.ErrorParse('`jsonrpc` property must be defined'); + } + if (message.jsonrpc !== '2.0') { + throw new validationErrors.ErrorParse( + '`jsonrpc` property must be a string of "2.0"', + ); + } + switch (message.type) { + case 'JsonRpcRequest': + return parseJsonRpcRequest(message); + case 'JsonRpcNotification': + return parseJsonRpcNotification(message); + case 'JsonRpcResponseResult': + return parseJsonRpcResponseResult(message); + case 'JsonRpcResponseError': + return parseJsonRpcResponseError(message); + default: + throw new validationErrors.ErrorParse( + '`type` property must be a valid type', + ); + } +} + +export { JsonToJsonMessageStream }; diff --git a/test-ajv.ts b/test-ajv.ts new file mode 100644 index 000000000..bec582f79 --- /dev/null +++ b/test-ajv.ts @@ -0,0 +1,37 @@ +import { signedClaimValidate } from './src/claims/schema'; +import { ClaimIdEncoded, SignedClaim } from './src/claims/types'; +import { NodeIdEncoded } from './src/ids/types'; + +async function main () { + + const y: SignedClaim = { + payload: { + jti: 'abc' as ClaimIdEncoded, + nbf: 123, + iat: 456, + seq: 123, + prevClaimId: 'abc' as ClaimIdEncoded, + prevDigest: null, + iss: 'abc' as NodeIdEncoded, + sub: 'abc', + }, + signatures: [{ + protected: { + alg: "BLAKE2b" + }, + header: { + + }, + signature: "abc", + }] + }; + + const x = signedClaimValidate( + y + ); + + console.log(signedClaimValidate.errors); + +} + +main(); diff --git a/test-deepkit-rpc-client.ts b/test-deepkit-rpc-client.ts new file mode 100644 index 000000000..d6fe1be1f --- /dev/null +++ b/test-deepkit-rpc-client.ts @@ -0,0 +1,45 @@ +import { rpc, RpcKernel } from '@deepkit/rpc'; +// import { RpcClient } from '@deepkit/rpc'; +import { RpcWebSocketClient } from '@deepkit/rpc'; +// import { RpcTcpClientAdapter } from '@deepkit/rpc-tcp'; + +interface ControllerI { + hello(title: string): string; + getUser(): Promise; +} + +@rpc.controller('clientController') +class Controller { + @rpc.action() + hello(title: string): string { + return 'Hello ' + title; + } + + @rpc.action() + async getUser(): Promise { + return 'this is a user'; + } +} + +async function main () { + + const client = new RpcWebSocketClient('ws://localhost:8081'); + client.registerController(Controller, 'clientController'); + + const controller = client.controller('myController'); + + + // const result1 = await controller.hello('world'); + // const result2 = await controller.getUser(); + + // console.log(result1); + // console.log(result2); + + // client.disconnect(); +} + +main(); + + + +// instresting diff --git a/test-deepkit-rpc-server.ts b/test-deepkit-rpc-server.ts new file mode 100644 index 000000000..5c24ff59a --- /dev/null +++ b/test-deepkit-rpc-server.ts @@ -0,0 +1,44 @@ +import { rpc, RpcKernel } from '@deepkit/rpc'; +import { RpcWebSocketServer } from '@deepkit/rpc-tcp'; + +@rpc.controller('Con') +class Con { + @rpc.action() + hello(title: string): string { + return 'Hello ' + title; + } + + @rpc.action() + async getUser(): Promise { + return 'this is a user'; + } +} + +async function main () { + + const kernel = new RpcKernel(); + kernel.registerController(Con, 'Con'); + kernel.controllers + kernel.createConnection + kernel.onConnection((conn) => { + conn.clientAddress + conn.controller + conn.handleMessage + conn.myPeerId + conn.onClose + conn.onMessage + conn.writer + }); + + const server = new RpcWebSocketServer(kernel, 'ws://localhost:8081'); + + server.start({ + host: 'localhost', + port: 8081, + }); + + console.log('STARTED'); + // server.close(); +} + +main(); diff --git a/test-dgram.ts b/test-dgram.ts new file mode 100644 index 000000000..f43de4ad4 --- /dev/null +++ b/test-dgram.ts @@ -0,0 +1,43 @@ +import dgram from 'dgram'; + +// No other process should bebound on it? +// Binding to `::` is the default? +// Right I'm just wondering what it is bound to if we just send +// Default is `dns.lookup` +// The signal can be used to close the socket +const socket = dgram.createSocket('udp4'); + +socket.on('message', (msg, rinfo) => { + console.log(msg, rinfo); +}); + +socket.bind(55555, 'localhost', () => { + + const socket2 = dgram.createSocket('udp4'); + // Upon the first send, it will be bound + // But you can send it to different places + // But you don't have to bind it if you don't want to + // But then it will be randomly set upon the first send and repeatedly + socket2.bind(55551); + + socket2.send('abc', 55555, 'localhost', (e) => { + + console.log('done', e); + socket2.send('abc', 55555, 'localhost', (e) => { + console.log('done', e); + + socket2.send('abc', 55555, 'localhost', (e) => { + console.log('done', e); + + // socket.close(); + // socket2.close(); + + }); + + }); + + + }); + +}); + diff --git a/test-g.ts b/test-g.ts new file mode 100644 index 000000000..30300ecca --- /dev/null +++ b/test-g.ts @@ -0,0 +1,22 @@ +function *concatStrings(): Generator { + let result = ''; + while (true) { + const data = yield; + if (data === null) { + return result; + } + result += data; + } +} + +function *combine() { + return (yield* concatStrings()) + 'FINISH'; +} + +const g = combine(); +g.next(); +g.next("a"); +g.next("b"); +g.next("c"); +const r = g.next(null); +console.log(r.value); diff --git a/test-generator-exception.ts b/test-generator-exception.ts new file mode 100644 index 000000000..c0b51c950 --- /dev/null +++ b/test-generator-exception.ts @@ -0,0 +1,90 @@ +import process from 'process'; + +process.on('uncaughtException', () => { + console.log('Exception was uncaught'); +}); + +process.on('unhandledRejection', () => { + console.log('Rejection was unhandled'); +}); + +async function sleep(ms: number): Promise { + return await new Promise((r) => setTimeout(r, ms)); +} + +async function *gf1() { + let c = 0; + while (true) { + await sleep(100); + yield 'G1 string'; + if (c === 5) { + throw new Error('There is an Error!'); + } + c++; + } +} + +async function *gf2() { + while (true) { + await sleep(100); + try { + yield 'G2 string'; + } catch (e) { + // This yield is for the `throw` call + // It ends up being AWAITED FOR + yield; + // Then on the NEXT `next` call they will get an error + // That's how it has to work... LOL + throw(new Error('Wrapped Error')); + } + } +} + +async function main () { + const g1 = gf1(); + for (let i = 0; i < 10; i++) { + try { + console.log(await g1.next()); + } catch (e) { + console.log('Consumed an exception!'); + break; + } + } + + const g2 = gf2(); + setTimeout(async () => { + // await g.return(); + // Async generator + // If the thrown error is NOT caught + // this will return a Promise that REJECTS + // with the exception passed in + // void g2.throw(new Error('There is an Error!')).catch((e) => { + // console.log('IGNORING ERROR: ', e.message); + // }); + + console.log(await g2.throw(new Error('There is an Error!'))); + }, 250); + + for (let i = 0; i < 10; i++) { + try { + console.log(await g2.next()); + } catch (e) { + console.log('Consumed an exception!'); + break; + } + } + console.log(await g2.next()); + +} + +void main(); + +// Ok so when the stream has an exception +// If we use async generator throw +// The async generator is being consumed by the end user +// That exception cannot be passed into the `yield` +// Not even if I wait until the next loop +// Because under the while loop it will try to do that + +// The problem is here... the types will be a bit weird though +// So that's what you have to be careful about diff --git a/test-generators.ts b/test-generators.ts new file mode 100644 index 000000000..c17b898d1 --- /dev/null +++ b/test-generators.ts @@ -0,0 +1,377 @@ +import { Subject } from 'rxjs'; + +// This example demonstrates a simple handler with +// input async generator and output async generator + +async function sleep(ms: number): Promise { + return await new Promise((r) => setTimeout(r, ms)); +} + + +// Echo handler +async function* handler1( + input: AsyncIterableIterator +): AsyncGenerator { + // This will not preserve the `return` + // for await(const chunk of input) { + // yield chunk; + // } + + // This will also not preserve the `return` + // yield* input; + + // If we want to preserve the `return` + // We must use `return` here too + // Note that technically the `any` is required + // At the end, although technically that is not allowed + return yield* input; +} + +async function client1() { + console.log('CLIENT 1 START'); + async function* input() { + yield Buffer.from('hello'); + yield Buffer.from('world'); + return Buffer.from('end'); + } + // Assume the client gets `AsyncIterableIterator` + const output = handler1(input()) as AsyncIterableIterator; + // for await (const chunk of output) { + // console.log(chunk.toString()); + // } + while (true) { + const { done, value } = await output.next(); + if (Buffer.isBuffer(value)) { + console.log(value.toString()); + } else { + console.log('end with nothing'); + } + if (done) { + break; + } + } + console.log('CLIENT 1 STOP'); +} + +// Client Streaming +async function* handler2( + input: AsyncIterableIterator +): AsyncGenerator { + let chunks = Buffer.from(''); + for await(const chunk of input) { + chunks = Buffer.concat([chunks, chunk]); + } + return chunks; +} + +async function client2() { + console.log('CLIENT 2 START'); + async function* input() { + yield Buffer.from('hello'); + yield Buffer.from('world'); + } + const output = handler2(input()) as AsyncIterableIterator; + // Cannot use for..of for returned values + // Because the `return` is not maintained + let done: boolean | undefined = false; + while (!done) { + let value: Buffer; + ({ done, value } = await output.next()); + console.log(value.toString()); + } + console.log('CLIENT 2 STOP'); +} + +// Server streaming +async function* handler3( + _: AsyncIterableIterator +): AsyncGenerator { + // This handler doesn't care about the input + // It doesn't even bother processing it + yield Buffer.from('hello'); + yield Buffer.from('world'); + // Can we use the `return` to indicate an "early close"? + return Buffer.from('end'); + // It is possible to return `undefined` + // return; +} + +async function client3() { + console.log('CLIENT 3 START'); + // The RPC system can default `undefined` to be an empty async generator + const output = handler3((async function* () {})()) as AsyncIterableIterator; + while (true) { + const { done, value } = await output.next(); + if (Buffer.isBuffer(value)) { + console.log(value.toString()); + } else { + console.log('end with nothing'); + } + if (done) { + break; + } + } + console.log('CLIENT 3 STOP'); +} + +// Duplex streaming +// Pull-on both ends +async function *handler4( + input: AsyncIterableIterator +): AsyncGenerator { + // Note that + // the reason why we return `AsyncGenerator` + // Is because technically the user of this + // Can be used with `return()` and `throw()` + // But it is important to realise the types + // Can be more flexible + // We may wish to create our own types to be compatible + + // This concurrently consumes and concurrently produces + // The order is not sequenced + // How do we do this? + // Well something has to indicate consumption + // Something has to indicate production + // But they should be done in parallel + + // This is something that can be done + // by converting them to web streams (but that focuses on buffers) + // Alternatively by converting it to an event emitter? + // Or through rxjs... let's try that soon + + void (async () => { + // It can be expected that the input will end + // when the connection is stopped + // Or if abruptly we must consider the catching an exception + while (true) { + const { done, value } = await input.next(); + if (Buffer.isBuffer(value)) { + console.log('received', value.toString()); + } + if (done) { + console.log('received done'); + break; + } + } + })(); + + let counter = 0; + while (true) { + yield Buffer.from(counter.toString()); + counter++; + } + + // how do we know when to stop consuming? + // remember that once the connection stops + // we need to indicate that when we are finished + // remember that the thing should eventually be done + // otherwise we have a dangling promise + // that's kind of important + // wait we have an issue here + // how do we know when we are "finished"? + // or do we just `void` it? +} + +async function client4() { + console.log('CLIENT 4 START'); + async function *input() { + yield Buffer.from('hello'); + yield Buffer.from('world'); + return; + } + const output = handler4(input()) as AsyncIterableIterator; + console.log(await output.next()); + console.log(await output.next()); + console.log(await output.next()); + console.log(await output.next()); + console.log(await output.next()); + console.log(await output.next()); + + // if we want to "finish" the stream + // we can just stop consuming the `next()` + // But there's an issue here + console.log('CLIENT 4 STOP'); +} + +// How to "break" connection +async function* handler5( + input: AsyncIterableIterator +): AsyncGenerator { + while (true) { + let value, done; + try { + ({ value, done } = await input.next()); + } catch (e) { + console.log('SERVER GOT ERROR:', e.message); + break; + } + console.log('server received', value, done); + yield Buffer.from('GOT IT'); + if (done) { + console.log('server done'); + break; + } + } + return; +} + +async function client5() { + console.log('CLIENT 5 START'); + // In this scenario + async function* input() { + while (true) { + await sleep(100); + try { + yield Buffer.from('hello'); + } catch (e) { + yield; + throw e; + } + } + } + const inputG = input(); + const output = handler5(inputG as AsyncIterableIterator); + setTimeout(() => { + void inputG.throw(new Error('Connection Failed')); + }, 250); + while(true) { + const { done, value } = await output.next(); + console.log('client received', value); + if (done) { + break; + } + } + console.log('CLIENT 5 STOP'); +} + +// Convert to `push` + +// This is a push based system +// if you don't answer it +// let's see +const subject = new Subject(); + +subject.subscribe({ + next: (v) => console.log('PUSH:', v) +}); + +async function *handler6 ( + input: AsyncIterableIterator +): AsyncGenerator { + + // This is "done" asynchronously, while we pull from the stream + // How to do this in an asynchronus way? + const p = (async () => { + while (true) { + const { value, done } = await input.next(); + subject.next(value); + if (done) { + break; + } + } + })(); + + await sleep(100); + + yield Buffer.from('Hello'); + yield Buffer.from('World'); + // The stream is FINISHED + // but is the function call still completing? + // Consider what happens if that is the case + // We may want the function's lifecycle to be more complete + + // Await to finish this + // This is what allows you to capture any errors + // And the RPC system to throw it back up!!! + await p; + + // This sort of means that the OUTPUT stream isn't finished + // UNTIL you are finishign the INPUT stream + // This is a bit of a problem + // You can also send it earlier + // But if there's an exception in the processing... + + // WELL YEA... you'd need to keep the output stream open + // while you are consuming data + // otherwise you cannot signal that something failed + return; +} + +async function client6() { + console.log('CLIENT 6 START'); + // In this scenario + async function* input() { + yield Buffer.from('first'); + yield Buffer.from('second'); + return; + } + const output = handler6(input()); + while(true) { + const { done, value } = await output.next(); + console.log('client received', value); + if (done) { + break; + } + } + console.log('CLIENT 6 STOP'); +} + + +async function main() { + // await client1(); + // await client2(); + // await client3(); + // await client4(); + // await client5(); + await client6(); +} + +void main(); + +// We assume that the RPC wrapper would plumb the async generator data +// into the underlying web stream provided by the transport layer + +// The async generator `return` can be used to indicate and early +// finish to the the stream +// If `return;` is used, no last chunk is written +// If `return buf;` is used, then the buf is the last chunk to be written +// It also means the `value` could be `undefined` + +// It is possible to "force" a `return` to be applied on the outside +// this mean the `input` stream can be used +// Abort signal can also be used to indicate asynchronous cancellation +// But that is supposed to be used to cancel async operations +// Does the async generator for input stream also get a `ctx`? +// What about `throw`? Does it cancel the stream? + +// What about `ixjs`? Should this embed `ixjs`, so it can be more easily +// used? Technically ixjs works on the iterable, not on the generator +// It doesn't maintain the generator itself right? +// It would be nice if it was still a generator. + +// How to deal with metadata? For authentication... +// Is it part of the RPC system, leading and trailing metadata? +// Each message could have a metadata +// Does it depend on the RPC system itself? +// What about the transport layer? + +// Also if we enable generators +// we technically can communicate back +// that should be disallowed (since it doesn't make sense) +// perhaps we can use a different type +// Like `AsyncIterableIterator` instead of the same thing? +// It limits it to `next` +// Which is interesting + +// I think this is more correct +// You want to "take" an AsyncIterableIterator +// But the client side would get an AsyncIterableIterator +// But pass in an AsyncGenerator +// I think this makes more sense... +// async function *lol( +// x: AsyncIterableIterator +// ): AsyncGenerator { +// yield Buffer.from('hello'); +// return; +// } + diff --git a/test-gg.ts b/test-gg.ts new file mode 100644 index 000000000..90f3e7d88 --- /dev/null +++ b/test-gg.ts @@ -0,0 +1,211 @@ +import fc from 'fast-check'; +import type { ClaimIdEncoded, IdentityId, NodeId, ProviderId } from './src/ids'; +import { DB } from '@matrixai/db'; +import ACL from './src/acl/ACL'; +import GestaltGraph from './src/gestalts/GestaltGraph'; +import { IdInternal } from '@matrixai/id'; +import Logger, { LogLevel, StreamHandler, formatting } from '@matrixai/logger'; +import * as ids from './src/ids'; + +const nodeIdArb = fc.uint8Array({ minLength: 32, maxLength: 32 }).map( + IdInternal.create +) as fc.Arbitrary; + +// const nodeId = IdInternal.fromBuffer(Buffer.allocUnsafe(32)); + +async function main() { + + // Top level + // but we cannot raise the bottom level + // we can only hide levels + // or filter + // You could also set a filter + + const logger = new Logger( + 'TEST', + LogLevel.DEBUG, + [ + new StreamHandler( + formatting.format`${formatting.level}:${formatting.keys}:${formatting.msg}` + ), + ] + ); + + const dbLogger = logger.getChild('DB'); + dbLogger.setLevel(LogLevel.INFO); + + const db = await DB.createDB({ + dbPath: 'tmp/db', + logger: dbLogger, + fresh: true, + }); + + const aclLogger = logger.getChild('ACL'); + aclLogger.setLevel(LogLevel.INFO); + + const acl = await ACL.createACL({ + db, + logger: aclLogger, + }); + + + const ggLogger = logger.getChild('GestaltGraph'); + ggLogger.setLevel(LogLevel.DEBUG); + + const gg = await GestaltGraph.createGestaltGraph({ + db, + acl, + logger: ggLogger, + }); + + const nodeId1 = fc.sample(nodeIdArb, 1)[0]; + + + await gg.setNode({ + nodeId: nodeId1 + }); + + const nodeId2 = fc.sample(nodeIdArb, 1)[0]; + + await gg.setNode({ + nodeId: nodeId2, + }); + + const nodeId3 = fc.sample(nodeIdArb, 1)[0]; + + await gg.setNode({ + nodeId: nodeId3, + }); + + const nodeId4 = fc.sample(nodeIdArb, 1)[0]; + + await gg.setNode({ + nodeId: nodeId4, + }); + + const nodeId5 = fc.sample(nodeIdArb, 1)[0]; + + await gg.setNode({ + nodeId: nodeId5, + }); + + await gg.setIdentity({ + providerId: '123' as ProviderId, + identityId: 'abc' as IdentityId + }); + + await gg.linkNodeAndNode( + { + nodeId: nodeId1 + }, + { + nodeId: nodeId2 + }, + { + meta: {}, + claim: { + payload: { + iss: ids.encodeNodeId(nodeId1), + sub: ids.encodeNodeId(nodeId2), + jti: 'asfoiuadf' as ClaimIdEncoded, + iat: 123, + nbf: 123, + seq: 123, + prevClaimId: null, + prevDigest: null + }, + signatures: [] + } + } + ); + + await gg.linkNodeAndNode( + { + nodeId: nodeId1 + }, + { + nodeId: nodeId3 + }, + { + meta: {}, + claim: { + payload: { + iss: ids.encodeNodeId(nodeId1), + sub: ids.encodeNodeId(nodeId3), + jti: 'asfoiuadf' as ClaimIdEncoded, + iat: 123, + nbf: 123, + seq: 123, + prevClaimId: null, + prevDigest: null + }, + signatures: [] + } + } + ); + + await gg.linkNodeAndNode( + { + nodeId: nodeId2 + }, + { + nodeId: nodeId3 + }, + { + meta: {}, + claim: { + payload: { + iss: ids.encodeNodeId(nodeId2), + sub: ids.encodeNodeId(nodeId3), + jti: 'asfoiuadf' as ClaimIdEncoded, + iat: 123, + nbf: 123, + seq: 123, + prevClaimId: null, + prevDigest: null + }, + signatures: [] + } + } + ); + + // await gg.linkNodeAndNode( + // { + // nodeId: nodeId1 + // }, + // { + // nodeId: nodeId2 + // }, + // { + // type: 'node', + // meta: {}, + // claim: { + // payload: { + // jti: 's8d9sf98s7fd98sfd7' as ClaimIdEncoded, + // iss: ids.encodeNodeId(nodeId1), + // sub: ids.encodeNodeId(nodeId2), + // iat: 123, + // nbf: 123, + // seq: 123, + // prevClaimId: null, + // prevDigest: null + // }, + // signatures: [] + // } + // } + // ); + + console.log(await db.dump(gg.dbMatrixPath, true)); + // console.log(await db.dump(gg.dbNodesPath, true)); + // console.log(await db.dump(gg.dbLinksPath, true)); + + for await (const gestalt of gg.getGestalts()) { + console.group('Gestalt'); + console.dir(gestalt, { depth: null }); + // console.log('nodes', gestalt.nodes); + console.groupEnd(); + } + +} + +main(); diff --git a/test-hashing.ts b/test-hashing.ts new file mode 100644 index 000000000..cc8e4eed7 --- /dev/null +++ b/test-hashing.ts @@ -0,0 +1,37 @@ +import * as hash from './src/keys/utils/hash'; +import * as hashing from './src/tokens/utils'; + +async function main () { + + // thisis what it takes to do it + + const digest = hash.sha256(Buffer.from('hello world')); + console.log(hashing.sha256MultiHash(digest)); + + + + // const encodeR = await hashing.sha256M.encode(Buffer.from('abc')); + // const digestR = await hashing.sha256M.digest(Buffer.from('abc')); + + // console.log(encodeR.byteLength); + // console.log(encodeR); + + // console.log(digestR); + + // // so remember + // // that upon hashing, you have a multihash digest + + // // this is the actual byte reprentation + // // the remaining stuff still needs to be "multibase" encoded + // console.log(digestR.bytes); + + + // // so therefore + // // BASEENCODING + MULTIHASH is exactly what you want + + + + +} + +main(); diff --git a/test-muxrpc-client.ts b/test-muxrpc-client.ts new file mode 100644 index 000000000..15f86d3de --- /dev/null +++ b/test-muxrpc-client.ts @@ -0,0 +1,176 @@ +import MRPC from 'muxrpc'; +import pull from 'pull-stream'; +import toPull from 'stream-to-pull-stream'; +import net from 'net'; +import { sleep } from './src/utils'; + +const manifest = { + hello: 'async', + // another: 'async', + stuff: 'source', + + sink: 'sink', + + duplex: 'duplex', +}; + +// Client needs the remote manifest, it can pass a local manifest, and then a local API + +// Remote manifest, local manifest, codec +// Local API, Permissions, ID +const client = MRPC(manifest, null)(null, null, 'CLIENT'); + +console.log(client); + +const stream = toPull.duplex(net.connect(8080)); + +// const onClose = () => { +// console.log('closed connection to muxrpc server'); +// }; + + +const mStream = client.createStream(); + +// This also takes a duplex socket, and converts it to a "pull-stream" +pull(stream, mStream, stream); + +// So now that the client is composed there +// Also interestingly... notice that the TCP socket above does its own establishment +// The RPC hence is "transport" agnostic because it works on anything that is a duplex stream +// That's pretty good + +client.hello('world', 100, (err, data) => { + if (err != null) throw err; + console.log('HELLO call1', data); +}); + +// Oh cool, it does support promises on the client side + +// client.hello('world', 50, (err, data) => { +// if (err != null) throw err; +// console.log('HELLO call2', data); +// }); + +// client.hello('world', 10, (err, data) => { +// if (err != null) throw err; +// console.log('HELLO call3', data); +// }); + +// Yep there's a muxing of the RPC calls here +// This makes alot of sense + +// No deadline... it's not finished +// Ok then there's a failure, we have 1 stream per rpc +// client.another('another').then((data) => { +// console.log('ANOTHER', data); +// }); + +console.log('SENT all hellos over'); + +// Now if you want to do a stream, it seems `pull.values` ultimately returns some sort of stream object + +// const s = client.stuff(); +// // Yea this becomes a "source" stream +// // So it is infact the same type +// console.log('stuff stream', s); + +// pull(s, pull.drain(console.log)); + +// So how does this actually "mux" the RPC calls? +// Can they be concurrent? +// I think the muxing still has to end up +// interleaving the data... +// So it's still in sequence +// But the order can be different depending on the situation + + +// This is a sink +// we need to feed datat to the seek +// How do we know when things are done...? +// client.sink( +// pull.values(['hello', 'world']), +// (e, v) => { +// console.log('got it', v); +// } +// ); + +// const sink = client.sink(); + +// pull( +// pull.values(['hello', 'world']), +// sink +// ); + +// When a "stream" is opened here +// it prevents the process from shutting down +// That's kind of bad +// We don't really want to keep the process open + +// const duplex = client.duplex(); + +// console.log('DUPLEX', duplex); + +// pull( +// pull.values([1, 2, 3, 'end']), +// duplex, +// pull.drain(console.log) +// ); + +// Nothing is "ending" the stream +// that's the problem + +// console.log('YO'); + + +// The entire MUXRPC object is an event emitter + +// console.log(client.id); + +console.log(mStream); + +// This is also asynchronous +// It ends up closing a "ws" +// Which usese `initStream` +// YOU HAVE TO SUPPLY A CALLBACK +// client.end(); + +console.log('is open', mStream.isOpen()); + +// I think this is actually wht is perfomring a remote call +// whatever... +// console.log('remote call', mStream.remoteCall.toString()); + +mStream.close(() => { + console.log('CLOSING MUXRPC STREAM'); + + // Closing the stream also closes the client + // The client can create a stream... + // That's really strange + // Ok but the client is then closed too? + console.log(client.closed); +}); + + + +// client.close(() => { +// console.log('ClOSED'); +// }); + +// Remember TCP sockets are duplex streams +// So they are already duplex concurrently +// They are also event emitters at the same time + +// But dgram sockets are EventEmitters +// They are not Duplex stream +// They are not streams at all +// Which makes sense +// But as an event emitter, that makes them concurrent in both directions too +// Messages could be sent there and also received +// All UDP datagrams can be sent to alternative destinations, even if bound to the same socket + +// So this is very interesting + +// I'm still seeing a problem. +// How does the handlers get context of the RPC call? And how does it get access to the remote side's manifest? + + diff --git a/test-muxrpc-server.ts b/test-muxrpc-server.ts new file mode 100644 index 000000000..f24adee77 --- /dev/null +++ b/test-muxrpc-server.ts @@ -0,0 +1,201 @@ +import MRPC from 'muxrpc'; +import pull from 'pull-stream'; +import toPull from 'stream-to-pull-stream'; +import net from 'net'; +import pushable from 'pull-pushable'; +import { sleep } from './src/utils'; + +// "dynamic" manifest + +const manifest = { + hello: 'async', + // another: 'async', + stuff: 'source', + + // There's also `sink` streams + sink: 'sink', + + duplex: 'duplex', +}; + +// actual handlers, notice no promise support +const api = { + async hello(name, time, cb) { + // How are we supposed to know who contacted us? + // This is kind of stupid + + await sleep(time); + cb(null, 'hello' + name + '!'); + }, + // async another (name) { + // return 'hello' + name; + // }, + stuff() { + const s = pull.values([1,2,3,4,5]); + + // Yes, a "source" is a function + // This is a function + // Remember the "stream" is already mutated + // But this ends up returning a function + // This function basically starts the source + // The `cb` is used to read the data + // The cb gets called and it receives the data from the source + // That's where things get a bit complicated + // So the type of this is complex + // console.log('is this a source', s.toString()); + + // IT RETURNS THE SOURCE STREAM + return s; + }, + sink() { + // Cause it is a sink, it only takes data + // It does not give you back any confirmation back? + // IT RETURNS the sink stream... think about that + return pull.collect( + (e, arr) => { + console.log('SUNK', arr); + } + ); + }, + duplex() { + + // This needs to return a source and sink together + // Parameters are still passable into it at the beginning + // Sort of like how our duplex streams are structured + // We are able to pass it the initial message + + // The source cannot be `pull.values` + // Because it ends the stream + // That seems kind of wrong + + const p = pushable(); + // for (let i = 0; i < 5; i++) { + // p.push(i); + // } + // But this seems difficult to use + // How do we "consume" things, and then simultaneously + // push things to the source? + + return { + source: p, + sink: pull.drain((value) => { + // Wait I'm confused, how does this mean it ends? + // How do I know when this is ended? + + // If the `p` doesn't end + // We end up with a problem + + if (value === 'end') { + p.end(); + } + + p.push(value + 1); + }) + }; + }, +}; + +// Remote manifest, local manifest, codec +// Local API, Permissions, ID +const server = MRPC(null, manifest)(api, null, 'SERVER'); + +console.log(server); + +const muxStream = server.createStream(); + +net.createServer(socket => { + + console.log('NEW CONNECTION'); + + // The socket is converted to a duplex pull-stream + // Stream to Pull Stream is a conversion utility + // Converts NodeJS streams (classic-stream and new-stream) into a pull-stream + // It returns an object structure { source, sink } + // It is source(stream) and sink(stream, cb) + // The source will attach handlers on the `data`, `end`, `close`, and `error` events + // It pushes the stream data into the a buffers list. + // It also switches the stream to `rsume`, so it may be paused + // If the length exists, and the stream is pausable, then it will end up calling `stream.pause()` + // So upon handling a single `data` event, it will end up pausing the stream + // The idea being is that the buffer will have data + // Ok I get the idea... + // If the buffer still has data even after calling `drain`, then that means there's already data queued up + // That's why it pauses the stream + // If a stream is paused, data events will not be emitted + // The drawin runs a loop as long as there's data in the queue or ended, and the cbs is filled + // The dbs are callbacks, it shifts one of them, and applies it to a chunk of data + // Then it will check if the length is empty, and it is paused, it will then unpause it, and resume the stream + // On the write side, it is attaching handlers to the stream as well + // This time on close, finish and error. + // On the next tick ,it is then calling the `read` function + // Because it has to "read" a data from the source + // This callback then is given the data + // The data is written with `stream.write(data)` + // Anyway the point is that these is object of 2 things + // { source, sink } + const stream = toPull.duplex(socket); + + // This connects the output (source) of the stream to the muxrpc stream + // And it connects the output of the muxRPC stream to the net stream + // This is variadic function, it can take multiple things that are streams + // It pulls from left to right + // So the stream source is pulled into the muxrpc stream and then pulled into the stream again + // NET SOCKET SOURCE -> MUXRPC -> NET SOCKET SINK + // The duplex socket is being broken apart into 2 separate streams + // Then they get composed with the input and output of the muxrpc stream + // And therefore muxrpc is capable multiplexing its internal messages on top of the base socket streams + + + pull(stream, muxStream, stream); + + // every time a new connection occurs + // we have to do something different... + // I think... otherwise the streams get screwed up + // But I'm not entirely sure + // How do we do this over and over + // We have to "disconnect" + + socket.on('close', () => { + console.log('SOCKET CLOSED'); + // muxStream.close(); + }); + + socket.on('end', () => { + console.log('SOCKET ENDED'); + }); + + +}).listen(8080); + +// In a way, this pull-stream is kind of similar to the idea behind rxjs +// But it's very stream heavy, lack of typescript... etc +// Also being a pull stream, it only pulls from the source when it needs to +// I'm not sure what actually triggers it, it seems the source is given a "read" function +// So when the sink is ready, it ends up calling the `read` function + +// The looper is used for "asynchronous" looping +// Because it uses the `next` call to cycle +// This is necessary in asynchronous callbacks +// However this is not necessary if we are using async await syntax +// In a callback situation you cannot just use `while() { ... }` loop +// But you casn when using async await +// I'm confused about the `function (next) { .. }` cause the looper is not passing +// anything into the `next` parameter, so that doesn't make sense +// Right it is using the 3.0.0 of looper, which had a different design +// Ok so the point is, it's a process next tick, with an asynchronous infinite loop +// The loop repeatedly calls `read` upon finishing the `write` callback +// And it will do this until the read callback is ended +// Or if the output stream itself is ended +// This is what gives it a natural form of backpressure +// It will only "pull" something as fast as the sink can take it +// Since it only triggers a pull, when the sink is drained + +// An async iterator/iterator can do the same thing +// And thus one can "next" it as fast as the sink can read + +// Similarly a push stream would be subscribing, but it's possible to backpressure this with +// buffers or with dropping systems... naturally buffers should be used, and the application can drop data + +// This is an old structure, and I prefer modern JS with functional concepts + +// Now that the server is listening, we can create the client diff --git a/test-subject.ts b/test-subject.ts new file mode 100644 index 000000000..dda09ede9 --- /dev/null +++ b/test-subject.ts @@ -0,0 +1,20 @@ +import { Subject } from 'rxjs'; + +const subject = new Subject(); + +// These are dropped, nobody is listening +subject.next(1); + +subject.subscribe({ + next: (v) => console.log(`observerA: ${v}`) +}); + +subject.next(2); + +// B only gets 3 and 4 +subject.subscribe({ + next: (v) => console.log(`observerB: ${v}`) +}); + +subject.next(3); +subject.next(4); diff --git a/tests/rpc/utils.test.ts b/tests/rpc/utils.test.ts new file mode 100644 index 000000000..e0b767eb8 --- /dev/null +++ b/tests/rpc/utils.test.ts @@ -0,0 +1,108 @@ +import type { JsonRpcRequest } from '@/rpc/types'; +import type { POJO } from '@/types'; +import { ReadableStream } from 'stream/web'; +import { testProp, fc } from '@fast-check/jest'; +import { AsyncIterableX as AsyncIterable } from 'ix/asynciterable'; +import { JsonToJsonMessageStream } from '@/rpc/utils'; +import 'ix/add/asynciterable-operators/toarray'; +import * as rpcErrors from '@/rpc/errors'; +import { + BufferStreamToNoisyStream, + BufferStreamToSnippedStream, +} from './utils'; + +const jsonRpcRequestArb = ( + method: fc.Arbitrary = fc.string().map((value) => 'rpc-' + value), + params: fc.Arbitrary = fc.jsonValue(), + requireParams: boolean = false, +) => { + const requiredKeys: ('jsonrpc' | 'method' | 'params' | 'id')[] = requireParams + ? ['params', 'jsonrpc', 'method', 'id'] + : ['jsonrpc', 'method', 'id']; + return fc.record( + { + jsonrpc: fc.constant('2.0'), + method, + params: params.map((params) => JSON.parse(JSON.stringify(params))), + id: fc.integer({ min: 0 }), + }, + { + requiredKeys, + }, + ) as fc.Arbitrary>; +}; + +describe('utils tests', () => { + const jsonRpcStream = (messages: Array) => { + return new ReadableStream({ + async start(controller) { + for (const arrayElement of messages) { + // Controller.enqueue(arrayElement) + controller.enqueue( + Buffer.from(JSON.stringify(arrayElement), 'utf-8'), + ); + } + controller.close(); + }, + }); + }; + + const snippingPatternArb = fc + .array(fc.integer({ min: 1, max: 32 }), { minLength: 100, size: 'medium' }) + .noShrink(); + + const jsonMessagesArb = fc + .array(jsonRpcRequestArb(), { minLength: 2 }) + .noShrink(); + + testProp( + 'can parse json stream', + [jsonMessagesArb], + async (messages) => { + const parsedStream = jsonRpcStream(messages).pipeThrough( + new JsonToJsonMessageStream(), + ); // Converting back. + + const asd = await AsyncIterable.as(parsedStream).toArray(); + expect(asd).toEqual(messages); + }, + { numRuns: 1000 }, + ); + + testProp( + 'can parse json stream with random chunk sizes', + [jsonMessagesArb, snippingPatternArb], + async (messages, snippattern) => { + const parsedStream = jsonRpcStream(messages) + .pipeThrough(new BufferStreamToSnippedStream(snippattern)) // Imaginary internet here + .pipeThrough(new JsonToJsonMessageStream()); // Converting back. + + const asd = await AsyncIterable.as(parsedStream).toArray(); + expect(asd).toStrictEqual(messages); + }, + { numRuns: 1000 }, + ); + + const noiseArb = fc + .array( + fc.uint8Array({ minLength: 5 }).map((array) => Buffer.from(array)), + { minLength: 5 }, + ) + .noShrink(); + + testProp( + 'Will error on bad data', + [jsonMessagesArb, snippingPatternArb, noiseArb], + async (messages, snippattern, noise) => { + const parsedStream = jsonRpcStream(messages) + .pipeThrough(new BufferStreamToSnippedStream(snippattern)) // Imaginary internet here + .pipeThrough(new BufferStreamToNoisyStream(noise)) // Adding bad data to the stream + .pipeThrough(new JsonToJsonMessageStream()); // Converting back. + + await expect(AsyncIterable.as(parsedStream).toArray()).rejects.toThrow( + rpcErrors.ErrorRpcParse, + ); + }, + { numRuns: 1000 }, + ); +}); diff --git a/tests/rpc/utils.ts b/tests/rpc/utils.ts new file mode 100644 index 000000000..32d3950d5 --- /dev/null +++ b/tests/rpc/utils.ts @@ -0,0 +1,79 @@ +import type { + Transformer, + TransformerFlushCallback, + TransformerTransformCallback, +} from 'stream/web'; +import { TransformStream } from 'stream/web'; + +class BufferStreamToSnipped implements Transformer { + protected buffer = Buffer.alloc(0); + protected iteration = 0; + protected snippingPattern: Array; + + constructor(snippingPattern: Array) { + this.snippingPattern = snippingPattern; + } + + transform: TransformerTransformCallback = async ( + chunk, + controller, + ) => { + this.buffer = Buffer.concat([this.buffer, chunk]); + while (true) { + const snipAmount = + this.snippingPattern[this.iteration % this.snippingPattern.length]; + if (snipAmount > this.buffer.length) break; + this.iteration += 1; + const returnBuffer = this.buffer.subarray(0, snipAmount); + controller.enqueue(returnBuffer); + this.buffer = this.buffer.subarray(snipAmount); + } + }; + + flush: TransformerFlushCallback = (controller) => { + controller.enqueue(this.buffer); + }; +} + +/** + * This is used to convert regular chunks into randomly sized chunks based on + * a provided pattern. This is to replicate randomness introduced by packets + * splitting up the data. + */ +class BufferStreamToSnippedStream extends TransformStream { + constructor(snippingPattern: Array) { + super(new BufferStreamToSnipped(snippingPattern)); + } +} + +class BufferStreamToNoisy implements Transformer { + protected iteration = 0; + protected noise: Array; + + constructor(noise: Array) { + this.noise = noise; + } + + transform: TransformerTransformCallback = async ( + chunk, + controller, + ) => { + const noiseBuffer = this.noise[this.iteration % this.noise.length]; + const newBuffer = Buffer.from(Buffer.concat([chunk, noiseBuffer])); + controller.enqueue(newBuffer); + this.iteration += 1; + }; +} + +/** + * This is used to convert regular chunks into randomly sized chunks based on + * a provided pattern. This is to replicate randomness introduced by packets + * splitting up the data. + */ +class BufferStreamToNoisyStream extends TransformStream { + constructor(noise: Array) { + super(new BufferStreamToNoisy(noise)); + } +} + +export { BufferStreamToSnippedStream, BufferStreamToNoisyStream }; From 6418c58e2025f1363503f072ca2071ef5c2084dc Mon Sep 17 00:00:00 2001 From: Brian Botha Date: Thu, 5 Jan 2023 19:31:36 +1100 Subject: [PATCH 02/44] tests: expanding message aritraries and testing parsers --- src/rpc/types.ts | 21 ++++---- src/rpc/utils.ts | 35 ++++++------ tests/rpc/utils.test.ts | 117 ++++++++++++++++++++++++++++++---------- 3 files changed, 116 insertions(+), 57 deletions(-) diff --git a/src/rpc/types.ts b/src/rpc/types.ts index b382bc8d5..0208ec601 100644 --- a/src/rpc/types.ts +++ b/src/rpc/types.ts @@ -3,7 +3,7 @@ import type { POJO } from '../types'; /** * This is the JSON RPC request object. this is the generic message type used for the RPC. */ -type JsonRpcRequest = { +type JsonRpcRequest = { type: 'JsonRpcRequest'; // A String specifying the version of the JSON-RPC protocol. MUST be exactly "2.0" jsonrpc: '2.0'; @@ -20,7 +20,7 @@ type JsonRpcRequest = { id: string | number | null; }; -type JsonRpcNotification = { +type JsonRpcNotification = { type: 'JsonRpcNotification'; // A String specifying the version of the JSON-RPC protocol. MUST be exactly "2.0" jsonrpc: '2.0'; @@ -33,7 +33,7 @@ type JsonRpcNotification = { params?: T; }; -type JsonRpcResponseResult = { +type JsonRpcResponseResult = { type: 'JsonRpcResponseResult'; // A String specifying the version of the JSON-RPC protocol. MUST be exactly "2.0". jsonrpc: '2.0'; @@ -48,7 +48,7 @@ type JsonRpcResponseResult = { id: string | number | null; }; -type JsonRpcResponseError = { +type JsonRpcResponseError = { type: 'JsonRpcResponseError'; // A String specifying the version of the JSON-RPC protocol. MUST be exactly "2.0". jsonrpc: '2.0'; @@ -75,7 +75,7 @@ type JsonRpcResponseError = { // -32603 Internal error Internal JSON-RPC error. // -32000 to -32099 -type JsonRpcError = { +type JsonRpcError = { // A Number that indicates the error type that occurred. // This MUST be an integer. code: number; @@ -88,11 +88,12 @@ type JsonRpcError = { data?: T; }; -type JsonRpcResponse = - | JsonRpcResponseResult - | JsonRpcResponseError; +type JsonRpcResponse< + T extends POJO | unknown = unknown, + K extends POJO | unknown = unknown, +> = JsonRpcResponseResult | JsonRpcResponseError; -type jsonRpcMessage = +type JsonRpcMessage = | JsonRpcRequest | JsonRpcNotification | JsonRpcResponseResult @@ -105,5 +106,5 @@ export type { JsonRpcResponseError, JsonRpcError, JsonRpcResponse, - jsonRpcMessage, + JsonRpcMessage, }; diff --git a/src/rpc/utils.ts b/src/rpc/utils.ts index 82ceab7de..b5f4d3c96 100644 --- a/src/rpc/utils.ts +++ b/src/rpc/utils.ts @@ -1,7 +1,7 @@ import type { Transformer, TransformerTransformCallback } from 'stream/web'; import type { JsonRpcError, - jsonRpcMessage, + JsonRpcMessage, JsonRpcNotification, JsonRpcRequest, JsonRpcResponseError, @@ -9,9 +9,9 @@ import type { } from 'rpc/types'; import type { POJO } from '../types'; import { TransformStream } from 'stream/web'; -import * as utils from 'utils/index'; -import * as validationErrors from 'validation/errors'; import * as rpcErrors from './errors'; +import * as utils from '../utils'; +import * as validationErrors from '../validation/errors'; import { promise } from '../utils'; const jsonStreamParsers = require('@streamparser/json'); @@ -84,13 +84,6 @@ class JsonToJsonMessageStream extends TransformStream { } } -const messagetypes = [ - 'JsonRpcRequest', - 'JsonRpcNotification', - 'JsonRpcResponseResult', - 'JsonRpcResponseError', -]; - function parseJsonRpcRequest( message: unknown, ): JsonRpcRequest { @@ -123,7 +116,7 @@ function parseJsonRpcRequest( if ( typeof message.id !== 'string' && typeof message.id !== 'number' && - typeof message.id !== null + message.id !== null ) { throw new validationErrors.ErrorParse( '`id` property must be a string, number or null', @@ -198,7 +191,7 @@ function parseJsonRpcResponseResult( if ( typeof message.id !== 'string' && typeof message.id !== 'number' && - typeof message.id !== null + message.id !== null ) { throw new validationErrors.ErrorParse( '`id` property must be a string, number or null', @@ -239,7 +232,7 @@ function parseJsonRpcResponseError( if ( typeof message.id !== 'string' && typeof message.id !== 'number' && - typeof message.id !== null + message.id !== null ) { throw new validationErrors.ErrorParse( '`id` property must be a string, number or null', @@ -274,7 +267,7 @@ function parseJsonRpcError(message: unknown): JsonRpcError { function parseJsonRpcMessage( message: unknown, -): jsonRpcMessage { +): JsonRpcMessage { if (!utils.isObject(message)) { throw new validationErrors.ErrorParse('must be a JSON POJO'); } @@ -284,11 +277,6 @@ function parseJsonRpcMessage( if (typeof message.type !== 'string') { throw new validationErrors.ErrorParse('`type` property must be a string'); } - if (!(message.type in messagetypes)) { - throw new validationErrors.ErrorParse( - '`type` property must be a valid type', - ); - } if (!('jsonrpc' in message)) { throw new validationErrors.ErrorParse('`jsonrpc` property must be defined'); } @@ -313,4 +301,11 @@ function parseJsonRpcMessage( } } -export { JsonToJsonMessageStream }; +export { + JsonToJsonMessageStream, + parseJsonRpcRequest, + parseJsonRpcNotification, + parseJsonRpcResponseResult, + parseJsonRpcResponseError, + parseJsonRpcMessage, +}; diff --git a/tests/rpc/utils.test.ts b/tests/rpc/utils.test.ts index e0b767eb8..20db6a5ba 100644 --- a/tests/rpc/utils.test.ts +++ b/tests/rpc/utils.test.ts @@ -1,9 +1,15 @@ -import type { JsonRpcRequest } from '@/rpc/types'; +import type { + JsonRpcError, + JsonRpcMessage, + JsonRpcRequest, + JsonRpcResponseError, +} from '@/rpc/types'; import type { POJO } from '@/types'; +import type { JsonRpcNotification, JsonRpcResponseResult } from '@/rpc/types'; import { ReadableStream } from 'stream/web'; import { testProp, fc } from '@fast-check/jest'; import { AsyncIterableX as AsyncIterable } from 'ix/asynciterable'; -import { JsonToJsonMessageStream } from '@/rpc/utils'; +import * as rpcUtils from '@/rpc/utils'; import 'ix/add/asynciterable-operators/toarray'; import * as rpcErrors from '@/rpc/errors'; import { @@ -11,27 +17,6 @@ import { BufferStreamToSnippedStream, } from './utils'; -const jsonRpcRequestArb = ( - method: fc.Arbitrary = fc.string().map((value) => 'rpc-' + value), - params: fc.Arbitrary = fc.jsonValue(), - requireParams: boolean = false, -) => { - const requiredKeys: ('jsonrpc' | 'method' | 'params' | 'id')[] = requireParams - ? ['params', 'jsonrpc', 'method', 'id'] - : ['jsonrpc', 'method', 'id']; - return fc.record( - { - jsonrpc: fc.constant('2.0'), - method, - params: params.map((params) => JSON.parse(JSON.stringify(params))), - id: fc.integer({ min: 0 }), - }, - { - requiredKeys, - }, - ) as fc.Arbitrary>; -}; - describe('utils tests', () => { const jsonRpcStream = (messages: Array) => { return new ReadableStream({ @@ -47,12 +32,81 @@ describe('utils tests', () => { }); }; + const jsonRpcRequestArb = fc + .record( + { + type: fc.constant('JsonRpcRequest'), + jsonrpc: fc.constant('2.0'), + method: fc.string(), + params: fc.object(), + id: fc.oneof(fc.string(), fc.integer(), fc.constant(null)), + }, + { + requiredKeys: ['type', 'jsonrpc', 'method', 'id'], + }, + ) + .noShrink() as fc.Arbitrary; + + const jsonRpcNotificationArb = fc + .record( + { + type: fc.constant('JsonRpcNotification'), + jsonrpc: fc.constant('2.0'), + method: fc.string(), + params: fc.object(), + }, + { + requiredKeys: ['type', 'jsonrpc', 'method'], + }, + ) + .noShrink() as fc.Arbitrary; + + const jsonRpcResponseResultArb = fc + .record({ + type: fc.constant('JsonRpcResponseResult'), + jsonrpc: fc.constant('2.0'), + result: fc.object(), + id: fc.oneof(fc.string(), fc.integer(), fc.constant(null)), + }) + .noShrink() as fc.Arbitrary; + + const jsonRpcErrorArb = fc + .record( + { + code: fc.integer(), + message: fc.string(), + data: fc.object(), + }, + { + requiredKeys: ['code', 'message'], + }, + ) + .noShrink() as fc.Arbitrary; + + const jsonRpcResponseErrorArb = fc + .record({ + type: fc.constant('JsonRpcResponseError'), + jsonrpc: fc.constant('2.0'), + error: jsonRpcErrorArb, + id: fc.oneof(fc.string(), fc.integer(), fc.constant(null)), + }) + .noShrink() as fc.Arbitrary; + + const jsonRpcMessageArb = fc + .oneof( + jsonRpcRequestArb, + jsonRpcNotificationArb, + jsonRpcResponseResultArb, + jsonRpcResponseErrorArb, + ) + .noShrink() as fc.Arbitrary; + const snippingPatternArb = fc .array(fc.integer({ min: 1, max: 32 }), { minLength: 100, size: 'medium' }) .noShrink(); const jsonMessagesArb = fc - .array(jsonRpcRequestArb(), { minLength: 2 }) + .array(jsonRpcRequestArb, { minLength: 2 }) .noShrink(); testProp( @@ -60,7 +114,7 @@ describe('utils tests', () => { [jsonMessagesArb], async (messages) => { const parsedStream = jsonRpcStream(messages).pipeThrough( - new JsonToJsonMessageStream(), + new rpcUtils.JsonToJsonMessageStream(), ); // Converting back. const asd = await AsyncIterable.as(parsedStream).toArray(); @@ -75,7 +129,7 @@ describe('utils tests', () => { async (messages, snippattern) => { const parsedStream = jsonRpcStream(messages) .pipeThrough(new BufferStreamToSnippedStream(snippattern)) // Imaginary internet here - .pipeThrough(new JsonToJsonMessageStream()); // Converting back. + .pipeThrough(new rpcUtils.JsonToJsonMessageStream()); // Converting back. const asd = await AsyncIterable.as(parsedStream).toArray(); expect(asd).toStrictEqual(messages); @@ -97,7 +151,7 @@ describe('utils tests', () => { const parsedStream = jsonRpcStream(messages) .pipeThrough(new BufferStreamToSnippedStream(snippattern)) // Imaginary internet here .pipeThrough(new BufferStreamToNoisyStream(noise)) // Adding bad data to the stream - .pipeThrough(new JsonToJsonMessageStream()); // Converting back. + .pipeThrough(new rpcUtils.JsonToJsonMessageStream()); // Converting back. await expect(AsyncIterable.as(parsedStream).toArray()).rejects.toThrow( rpcErrors.ErrorRpcParse, @@ -105,4 +159,13 @@ describe('utils tests', () => { }, { numRuns: 1000 }, ); + + testProp( + 'can parse messages', + [jsonRpcMessageArb], + async (message) => { + rpcUtils.parseJsonRpcMessage(message); + }, + { numRuns: 1000 }, + ); }); From 8cf279ee2da8ae7f9aeae257ac0f35537d085e75 Mon Sep 17 00:00:00 2001 From: Brian Botha Date: Fri, 6 Jan 2023 17:44:21 +1100 Subject: [PATCH 03/44] feat: rpc server --- src/rpc/Rpc.ts | 271 ++++++++++++++++++++++++++++++++++++++++ src/rpc/errors.ts | 41 +++++- src/rpc/types.ts | 48 +++++-- src/rpc/utils.ts | 67 ++++++---- tests/rpc/Rpc.test.ts | 188 ++++++++++++++++++++++++++++ tests/rpc/utils.test.ts | 131 +++---------------- tests/rpc/utils.ts | 166 +++++++++++++++++++++++- 7 files changed, 761 insertions(+), 151 deletions(-) create mode 100644 src/rpc/Rpc.ts create mode 100644 tests/rpc/Rpc.test.ts diff --git a/src/rpc/Rpc.ts b/src/rpc/Rpc.ts new file mode 100644 index 000000000..4eb7a22bd --- /dev/null +++ b/src/rpc/Rpc.ts @@ -0,0 +1,271 @@ +import type { + ClientStreamHandler, + DuplexStreamHandler, + JsonRpcError, + JsonRpcMessage, + JsonRpcResponseError, + JsonRpcResponseResult, + ServerStreamHandler, +} from './types'; +import type { ReadableWritablePair } from 'stream/web'; +import type { JSONValue, POJO } from '../types'; +import type { ConnectionInfo } from '../network/types'; +import type { UnaryHandler } from './types'; +import { ReadableStream } from 'stream/web'; +import { + CreateDestroyStartStop, + ready, +} from '@matrixai/async-init/dist/CreateDestroyStartStop'; +import Logger from '@matrixai/logger'; +import { PromiseCancellable } from '@matrixai/async-cancellable'; +import * as rpcErrors from './errors'; +import * as rpcUtils from './utils'; +import * as grpcUtils from '../grpc/utils'; + +// FIXME: Might need to be StartStop. Won't know for sure until it's used. +interface Rpc extends CreateDestroyStartStop {} +@CreateDestroyStartStop( + new rpcErrors.ErrorRpcRunning(), + new rpcErrors.ErrorRpcDestroyed(), +) +class Rpc { + static async createRpc({ + container, + logger = new Logger(this.name), + }: { + container: POJO; + logger?: Logger; + }): Promise { + logger.info(`Creating ${this.name}`); + const rpc = new this({ + container, + logger, + }); + await rpc.start(); + logger.info(`Created ${this.name}`); + return rpc; + } + + // Properties + protected container: POJO; + protected logger: Logger; + protected handlerMap: Map> = + new Map(); + private activeStreams: Set> = new Set(); + + public constructor({ + container, + logger, + }: { + container: POJO; + logger: Logger; + }) { + this.container = container; + this.logger = logger; + } + + public async start(): Promise { + this.logger.info(`Starting ${this.constructor.name}`); + this.logger.info(`Started ${this.constructor.name}`); + } + + public async stop(): Promise { + this.logger.info(`Stopping ${this.constructor.name}`); + // Stopping any active steams + const activeStreams = this.activeStreams; + for await (const [activeStream] of activeStreams.entries()) { + activeStream.cancel(new rpcErrors.ErrorRpcStopping()); + } + this.logger.info(`Stopped ${this.constructor.name}`); + } + + public async destroy(): Promise { + this.logger.info(`Destroying ${this.constructor.name}`); + this.logger.info(`Destroyed ${this.constructor.name}`); + } + + @ready(new rpcErrors.ErrorRpcNotRunning()) + public registerDuplexStreamHandler( + method: string, + handler: DuplexStreamHandler, + ) { + this.handlerMap.set(method, handler); + } + + @ready(new rpcErrors.ErrorRpcNotRunning()) + public registerUnaryHandler( + method: string, + handler: UnaryHandler, + ) { + const wrapperDuplex: DuplexStreamHandler = async function* ( + input, + container, + connectionInfo, + ctx, + ) { + let count = 0; + for await (const inputVal of input) { + if (count > 1) throw new rpcErrors.ErrorRpcProtocal(); + yield handler(inputVal, container, connectionInfo, ctx); + count += 1; + } + }; + this.handlerMap.set(method, wrapperDuplex); + } + + @ready(new rpcErrors.ErrorRpcNotRunning()) + public registerClientStreamHandler( + method: string, + handler: ClientStreamHandler, + ) { + const wrapperDuplex: DuplexStreamHandler = async function* ( + input, + container, + connectionInfo, + ctx, + ) { + let count = 0; + for await (const inputVal of input) { + if (count > 1) throw new rpcErrors.ErrorRpcProtocal(); + yield* handler(inputVal, container, connectionInfo, ctx); + count += 1; + } + }; + this.handlerMap.set(method, wrapperDuplex); + } + + @ready(new rpcErrors.ErrorRpcNotRunning()) + public registerServerStreamHandler( + method: string, + handler: ServerStreamHandler, + ) { + const wrapperDuplex: DuplexStreamHandler = async function* ( + input, + container, + connectionInfo, + ctx, + ) { + yield handler(input, container, connectionInfo, ctx); + }; + this.handlerMap.set(method, wrapperDuplex); + } + + @ready(new rpcErrors.ErrorRpcNotRunning()) + public handleStream( + streamPair: ReadableWritablePair, + connectionInfo: ConnectionInfo, + ) { + // This will take a buffer stream of json messages and set up service + // handling for it. + let resolve: (value: void | PromiseLike) => void; + const abortController = new AbortController(); + const handlerProm2: PromiseCancellable = new PromiseCancellable( + (resolve_) => { + resolve = resolve_; + }, + abortController, + ); + this.activeStreams.add(handlerProm2); + void handlerProm2.finally(() => this.activeStreams.delete(handlerProm2)); + // While ReadableStream can be converted to AsyncIterable, we want it as + // a generator. + const inputGen = async function* () { + const pojoStream = streamPair.readable.pipeThrough( + new rpcUtils.JsonToJsonMessageStream(), + ); + for await (const dataMessage of pojoStream) { + // Filtering for request and notification messages + if ( + dataMessage.type === 'JsonRpcRequest' || + dataMessage.type === 'JsonRpcNotification' + ) { + yield dataMessage; + } + } + }; + const container = this.container; + const handlerMap = this.handlerMap; + const ctx = { signal: abortController.signal }; + const outputGen = async function* (): AsyncGenerator { + // Step 1, authentication and establishment + // read the first message, lets assume the first message is always leading + // metadata. + const input = inputGen(); + if (ctx.signal.aborted) throw ctx.signal.reason; + const leadingMetadataMessage = await input.next(); + if (leadingMetadataMessage.done === true) { + throw Error('TMP Stream closed early'); + } + const method = leadingMetadataMessage.value.method; + const _metadata = leadingMetadataMessage.value.params; + const dataGen = async function* () { + for await (const data of input) { + yield data.params as JSONValue; + } + }; + // TODO: validation on metadata + const handler = handlerMap.get(method); + if (handler == null) { + // Failed to find handler, this is an error. We should respond with + // an error message. + throw new rpcErrors.ErrorRpcHandlerMissing( + `No handler registered for method: ${method}`, + ); + } + if (ctx.signal.aborted) throw ctx.signal.reason; + try { + for await (const response of handler( + dataGen(), + container, + connectionInfo, + ctx, + )) { + const responseMessage: JsonRpcResponseResult = { + type: 'JsonRpcResponseResult', + jsonrpc: '2.0', + result: response, + id: null, + }; + yield responseMessage; + } + } catch (e) { + // This would be an error from the handler or the streams. We should + // catch this and send an error message back through the stream. + const rpcError: JsonRpcError = { + code: e.exitCode, + message: e.description, + data: grpcUtils.fromError(e), + }; + const rpcErrorMessage: JsonRpcResponseError = { + type: 'JsonRpcResponseError', + jsonrpc: '2.0', + error: rpcError, + id: null, + }; + yield rpcErrorMessage; + } + resolve(); + }; + + const outputGenerator = outputGen(); + + const outputStream = new ReadableStream({ + pull: async (controller) => { + const { value, done } = await outputGenerator.next(); + if (done) { + controller.close(); + return; + } + controller.enqueue(value); + }, + cancel: async (reason) => { + await outputGenerator.throw(reason); + }, + }); + void outputStream + .pipeThrough(new rpcUtils.JsonMessageToJsonStream()) + .pipeTo(streamPair.writable); + } +} + +export default Rpc; diff --git a/src/rpc/errors.ts b/src/rpc/errors.ts index 31bd028ec..13549aeb4 100644 --- a/src/rpc/errors.ts +++ b/src/rpc/errors.ts @@ -2,9 +2,48 @@ import { ErrorPolykey, sysexits } from '../errors'; class ErrorRpc extends ErrorPolykey {} +class ErrorRpcRunning extends ErrorRpc { + static description = 'Rpc is running'; + exitCode = sysexits.USAGE; +} + +class ErrorRpcDestroyed extends ErrorRpc { + static description = 'Rpc is destroyed'; + exitCode = sysexits.USAGE; +} + +class ErrorRpcNotRunning extends ErrorRpc { + static description = 'Rpc is not running'; + exitCode = sysexits.USAGE; +} + +class ErrorRpcStopping extends ErrorRpc { + static description = 'Rpc is stopping'; + exitCode = sysexits.USAGE; +} + class ErrorRpcParse extends ErrorRpc { static description = 'Failed to parse Buffer stream'; exitCode = sysexits.SOFTWARE; } -export { ErrorRpc, ErrorRpcParse }; +class ErrorRpcHandlerMissing extends ErrorRpc { + static description = 'No handler was registered for the given method'; + exitCode = sysexits.USAGE; +} + +class ErrorRpcProtocal extends ErrorRpc { + static description = 'Unexpected behaviour during communication'; + exitCode = sysexits.PROTOCOL; +} + +export { + ErrorRpc, + ErrorRpcRunning, + ErrorRpcDestroyed, + ErrorRpcNotRunning, + ErrorRpcStopping, + ErrorRpcParse, + ErrorRpcHandlerMissing, + ErrorRpcProtocal, +}; diff --git a/src/rpc/types.ts b/src/rpc/types.ts index 0208ec601..71f5ebb21 100644 --- a/src/rpc/types.ts +++ b/src/rpc/types.ts @@ -1,9 +1,11 @@ -import type { POJO } from '../types'; +import type { JSONValue, POJO } from '../types'; +import type { ConnectionInfo } from '../network/types'; +import type { ContextCancellable } from '../contexts/types'; /** * This is the JSON RPC request object. this is the generic message type used for the RPC. */ -type JsonRpcRequest = { +type JsonRpcRequest = { type: 'JsonRpcRequest'; // A String specifying the version of the JSON-RPC protocol. MUST be exactly "2.0" jsonrpc: '2.0'; @@ -20,7 +22,7 @@ type JsonRpcRequest = { id: string | number | null; }; -type JsonRpcNotification = { +type JsonRpcNotification = { type: 'JsonRpcNotification'; // A String specifying the version of the JSON-RPC protocol. MUST be exactly "2.0" jsonrpc: '2.0'; @@ -33,7 +35,7 @@ type JsonRpcNotification = { params?: T; }; -type JsonRpcResponseResult = { +type JsonRpcResponseResult = { type: 'JsonRpcResponseResult'; // A String specifying the version of the JSON-RPC protocol. MUST be exactly "2.0". jsonrpc: '2.0'; @@ -48,7 +50,7 @@ type JsonRpcResponseResult = { id: string | number | null; }; -type JsonRpcResponseError = { +type JsonRpcResponseError = { type: 'JsonRpcResponseError'; // A String specifying the version of the JSON-RPC protocol. MUST be exactly "2.0". jsonrpc: '2.0'; @@ -75,7 +77,7 @@ type JsonRpcResponseError = { // -32603 Internal error Internal JSON-RPC error. // -32000 to -32099 -type JsonRpcError = { +type JsonRpcError = { // A Number that indicates the error type that occurred. // This MUST be an integer. code: number; @@ -89,16 +91,40 @@ type JsonRpcError = { }; type JsonRpcResponse< - T extends POJO | unknown = unknown, - K extends POJO | unknown = unknown, + T extends JSONValue | unknown = unknown, + K extends JSONValue | unknown = unknown, > = JsonRpcResponseResult | JsonRpcResponseError; -type JsonRpcMessage = +type JsonRpcMessage = | JsonRpcRequest | JsonRpcNotification | JsonRpcResponseResult | JsonRpcResponseError; +// Handler types +type Handler = ( + input: I, + container: POJO, + connectionInfo: ConnectionInfo, + ctx: ContextCancellable, +) => O; +type DuplexStreamHandler = Handler< + AsyncGenerator, + AsyncGenerator +>; +type ClientStreamHandler = Handler< + I, + AsyncGenerator +>; +type ServerStreamHandler = Handler< + AsyncGenerator, + Promise +>; +type UnaryHandler = Handler< + I, + Promise +>; + export type { JsonRpcRequest, JsonRpcNotification, @@ -107,4 +133,8 @@ export type { JsonRpcError, JsonRpcResponse, JsonRpcMessage, + DuplexStreamHandler, + ClientStreamHandler, + ServerStreamHandler, + UnaryHandler, }; diff --git a/src/rpc/utils.ts b/src/rpc/utils.ts index b5f4d3c96..b2b380f47 100644 --- a/src/rpc/utils.ts +++ b/src/rpc/utils.ts @@ -7,7 +7,7 @@ import type { JsonRpcResponseError, JsonRpcResponseResult, } from 'rpc/types'; -import type { POJO } from '../types'; +import type { JSONValue } from '../types'; import { TransformStream } from 'stream/web'; import * as rpcErrors from './errors'; import * as utils from '../utils'; @@ -15,7 +15,7 @@ import * as validationErrors from '../validation/errors'; import { promise } from '../utils'; const jsonStreamParsers = require('@streamparser/json'); -class JsonToJsonMessage implements Transformer { +class JsonToJsonMessage implements Transformer { protected buffer = Buffer.alloc(0); /** @@ -58,7 +58,7 @@ class JsonToJsonMessage implements Transformer { return await foundOffset.p; } - transform: TransformerTransformCallback = async ( + transform: TransformerTransformCallback = async ( chunk, controller, ) => { @@ -68,7 +68,9 @@ class JsonToJsonMessage implements Transformer { if (index <= 0) break; const outputBuffer = this.buffer.subarray(0, index + 1); try { - controller.enqueue(JSON.parse(outputBuffer.toString('utf-8'))); + const jsonObject = JSON.parse(outputBuffer.toString('utf-8')); + const jsonRpcMessage = parseJsonRpcMessage(jsonObject); + controller.enqueue(jsonRpcMessage); } catch (e) { throw new rpcErrors.ErrorRpcParse(undefined, { cause: e }); } @@ -78,13 +80,29 @@ class JsonToJsonMessage implements Transformer { } // TODO: rename to something more descriptive? -class JsonToJsonMessageStream extends TransformStream { +class JsonToJsonMessageStream extends TransformStream { constructor() { super(new JsonToJsonMessage()); } } -function parseJsonRpcRequest( +class JsonMessageToJson implements Transformer { + transform: TransformerTransformCallback = async ( + chunk, + controller, + ) => { + controller.enqueue(Buffer.from(JSON.stringify(chunk))); + }; +} + +// TODO: rename to something more descriptive? +class JsonMessageToJsonStream extends TransformStream { + constructor() { + super(new JsonMessageToJson()); + } +} + +function parseJsonRpcRequest( message: unknown, ): JsonRpcRequest { if (!utils.isObject(message)) { @@ -107,9 +125,9 @@ function parseJsonRpcRequest( if (typeof message.method !== 'string') { throw new validationErrors.ErrorParse('`method` property must be a string'); } - if ('params' in message && !utils.isObject(message.params)) { - throw new validationErrors.ErrorParse('`params` property must be a POJO'); - } + // If ('params' in message && !utils.isObject(message.params)) { + // throw new validationErrors.ErrorParse('`params` property must be a POJO'); + // } if (!('id' in message)) { throw new validationErrors.ErrorParse('`id` property must be defined'); } @@ -125,7 +143,7 @@ function parseJsonRpcRequest( return message as JsonRpcRequest; } -function parseJsonRpcNotification( +function parseJsonRpcNotification( message: unknown, ): JsonRpcNotification { if (!utils.isObject(message)) { @@ -148,16 +166,16 @@ function parseJsonRpcNotification( if (typeof message.method !== 'string') { throw new validationErrors.ErrorParse('`method` property must be a string'); } - if ('params' in message && !utils.isObject(message.params)) { - throw new validationErrors.ErrorParse('`params` property must be a POJO'); - } + // If ('params' in message && !utils.isObject(message.params)) { + // throw new validationErrors.ErrorParse('`params` property must be a POJO'); + // } if ('id' in message) { throw new validationErrors.ErrorParse('`id` property must not be defined'); } return message as JsonRpcNotification; } -function parseJsonRpcResponseResult( +function parseJsonRpcResponseResult( message: unknown, ): JsonRpcResponseResult { if (!utils.isObject(message)) { @@ -182,9 +200,9 @@ function parseJsonRpcResponseResult( '`error` property must not be defined', ); } - if (!utils.isObject(message.result)) { - throw new validationErrors.ErrorParse('`result` property must be a POJO'); - } + // If (!utils.isObject(message.result)) { + // throw new validationErrors.ErrorParse('`result` property must be a POJO'); + // } if (!('id' in message)) { throw new validationErrors.ErrorParse('`id` property must be defined'); } @@ -200,7 +218,7 @@ function parseJsonRpcResponseResult( return message as JsonRpcResponseResult; } -function parseJsonRpcResponseError( +function parseJsonRpcResponseError( message: unknown, ): JsonRpcResponseError { if (!utils.isObject(message)) { @@ -241,7 +259,9 @@ function parseJsonRpcResponseError( return message as JsonRpcResponseError; } -function parseJsonRpcError(message: unknown): JsonRpcError { +function parseJsonRpcError( + message: unknown, +): JsonRpcError { if (!utils.isObject(message)) { throw new validationErrors.ErrorParse('must be a JSON POJO'); } @@ -259,13 +279,13 @@ function parseJsonRpcError(message: unknown): JsonRpcError { '`message` property must be a string', ); } - if ('data' in message && !utils.isObject(message.data)) { - throw new validationErrors.ErrorParse('`data` property must be a POJO'); - } + // If ('data' in message && !utils.isObject(message.data)) { + // throw new validationErrors.ErrorParse('`data` property must be a POJO'); + // } return message as JsonRpcError; } -function parseJsonRpcMessage( +function parseJsonRpcMessage( message: unknown, ): JsonRpcMessage { if (!utils.isObject(message)) { @@ -303,6 +323,7 @@ function parseJsonRpcMessage( export { JsonToJsonMessageStream, + JsonMessageToJsonStream, parseJsonRpcRequest, parseJsonRpcNotification, parseJsonRpcResponseResult, diff --git a/tests/rpc/Rpc.test.ts b/tests/rpc/Rpc.test.ts new file mode 100644 index 000000000..80cdc823b --- /dev/null +++ b/tests/rpc/Rpc.test.ts @@ -0,0 +1,188 @@ +import type { DuplexStreamHandler, JsonRpcMessage } from '@/rpc/types'; +import type { JSONValue } from '@/types'; +import type { ConnectionInfo, Host, Port } from '@/network/types'; +import type { NodeId } from '@/ids'; +import type { ReadableWritablePair } from 'stream/web'; +import { testProp, fc } from '@fast-check/jest'; +import Logger, { LogLevel, StreamHandler } from '@matrixai/logger'; +import Rpc from '@/rpc/Rpc'; +import * as rpcErrors from '@/rpc/errors'; +import * as rpcTestUtils from './utils'; + +describe(`${Rpc.name}`, () => { + const logger = new Logger(`${Rpc.name} Test`, LogLevel.WARN, [ + new StreamHandler(), + ]); + + const methodName = 'testMethod'; + const specificMessageArb = fc + .array(rpcTestUtils.jsonRpcRequestArb(fc.constant(methodName)), { + minLength: 5, + }) + .noShrink(); + + testProp('can stream data', [specificMessageArb], async (messages) => { + const stream = rpcTestUtils.jsonRpcStream(messages); + const container = {}; + const rpc = await Rpc.createRpc({ container, logger }); + const [outputResult, outputStream] = rpcTestUtils.streamToArray(); + const readWriteStream: ReadableWritablePair = { + readable: stream, + writable: outputStream, + }; + + const duplexHandler: DuplexStreamHandler = + async function* (input, _container, _connectionInfo, _ctx) { + for await (const val of input) { + yield val; + } + }; + + rpc.registerDuplexStreamHandler(methodName, duplexHandler); + rpc.handleStream(readWriteStream, {} as ConnectionInfo); + await outputResult; + }); + + testProp( + 'Handler is provided with container', + [specificMessageArb], + async (messages) => { + const stream = rpcTestUtils.jsonRpcStream(messages); + const container = { + a: Symbol('a'), + B: Symbol('b'), + C: Symbol('c'), + }; + const rpc = await Rpc.createRpc({ container, logger }); + const [outputResult, outputStream] = rpcTestUtils.streamToArray(); + const readWriteStream: ReadableWritablePair = { + readable: stream, + writable: outputStream, + }; + + const duplexHandler: DuplexStreamHandler = + async function* (input, container_, _connectionInfo, _ctx) { + expect(container_).toBe(container); + for await (const val of input) { + yield val; + } + }; + + rpc.registerDuplexStreamHandler(methodName, duplexHandler); + rpc.handleStream(readWriteStream, {} as ConnectionInfo); + await outputResult; + }, + ); + + testProp( + 'Handler is provided with connectionInfo', + [specificMessageArb], + async (messages) => { + const stream = rpcTestUtils.jsonRpcStream(messages); + const connectionInfo: ConnectionInfo = { + localHost: 'hostA' as Host, + localPort: 12341 as Port, + remoteCertificates: [], + remoteHost: 'hostA' as Host, + remoteNodeId: 'asd' as unknown as NodeId, + remotePort: 12341 as Port, + }; + const container = {}; + const rpc = await Rpc.createRpc({ container, logger }); + const [outputResult, outputStream] = rpcTestUtils.streamToArray(); + const readWriteStream: ReadableWritablePair = { + readable: stream, + writable: outputStream, + }; + + const duplexHandler: DuplexStreamHandler = + async function* (input, _container, connectionInfo_, _ctx) { + expect(connectionInfo_).toBe(connectionInfo); + for await (const val of input) { + yield val; + } + }; + rpc.registerDuplexStreamHandler(methodName, duplexHandler); + rpc.handleStream(readWriteStream, {} as ConnectionInfo); + await outputResult; + }, + ); + + // Problem with the tap stream. It seems to block the whole stream. + // If I don't pipe the tap to the output we actually iterate over some data. + testProp.skip( + 'Handler can be aborted', + [specificMessageArb], + async (messages) => { + const stream = rpcTestUtils.jsonRpcStream(messages); + const container = {}; + const rpc = await Rpc.createRpc({ container, logger }); + const [outputResult, outputStream] = rpcTestUtils.streamToArray(); + let thing; + let lastMessage: JsonRpcMessage | undefined; + const tapStream = new rpcTestUtils.TapStream( + async (_, iteration) => { + if (iteration === 2) { + // @ts-ignore: kidnap private property + const activeStreams = rpc.activeStreams.values(); + for (const activeStream of activeStreams) { + thing = activeStream; + activeStream.cancel(new rpcErrors.ErrorRpcStopping()); + } + } + }, + ); + await tapStream.readable.pipeTo(outputStream); + const readWriteStream: ReadableWritablePair = { + readable: stream, + writable: tapStream.writable, + }; + + const duplexHandler: DuplexStreamHandler = + async function* (input, _container, _connectionInfo, ctx) { + for await (const val of input) { + if (ctx.signal.aborted) throw ctx.signal.reason; + yield val; + } + }; + rpc.registerDuplexStreamHandler(methodName, duplexHandler); + rpc.handleStream(readWriteStream, {} as ConnectionInfo); + await outputResult; + await expect(thing).toResolve(); + // Last message should be an error message + expect(lastMessage).toBeDefined(); + expect(lastMessage?.type).toBe('JsonRpcResponseError'); + }, + ); + + testProp('Handler yields nothing', [specificMessageArb], async (messages) => { + const stream = rpcTestUtils.jsonRpcStream(messages); + const container = {}; + const rpc = await Rpc.createRpc({ container, logger }); + const [outputResult, outputStream] = rpcTestUtils.streamToArray(); + const readWriteStream: ReadableWritablePair = { + readable: stream, + writable: outputStream, + }; + + const duplexHandler: DuplexStreamHandler = + async function* (input, _container, _connectionInfo, _ctx) { + for await (const _ of input) { + // Do nothing, just consume + } + }; + + rpc.registerDuplexStreamHandler(methodName, duplexHandler); + rpc.handleStream(readWriteStream, {} as ConnectionInfo); + await outputResult; + // We're just expecting no errors + }); + + // TODO: + // - Test for each type of handler + // - duplex + // - client stream + // - server stream + // - unary + // - Test odd conditions for handlers, like extra messages where 1 is expected. +}); diff --git a/tests/rpc/utils.test.ts b/tests/rpc/utils.test.ts index 20db6a5ba..75276b49c 100644 --- a/tests/rpc/utils.test.ts +++ b/tests/rpc/utils.test.ts @@ -1,121 +1,18 @@ -import type { - JsonRpcError, - JsonRpcMessage, - JsonRpcRequest, - JsonRpcResponseError, -} from '@/rpc/types'; -import type { POJO } from '@/types'; -import type { JsonRpcNotification, JsonRpcResponseResult } from '@/rpc/types'; -import { ReadableStream } from 'stream/web'; import { testProp, fc } from '@fast-check/jest'; import { AsyncIterableX as AsyncIterable } from 'ix/asynciterable'; import * as rpcUtils from '@/rpc/utils'; import 'ix/add/asynciterable-operators/toarray'; import * as rpcErrors from '@/rpc/errors'; -import { - BufferStreamToNoisyStream, - BufferStreamToSnippedStream, -} from './utils'; +import * as rpcTestUtils from './utils'; describe('utils tests', () => { - const jsonRpcStream = (messages: Array) => { - return new ReadableStream({ - async start(controller) { - for (const arrayElement of messages) { - // Controller.enqueue(arrayElement) - controller.enqueue( - Buffer.from(JSON.stringify(arrayElement), 'utf-8'), - ); - } - controller.close(); - }, - }); - }; - - const jsonRpcRequestArb = fc - .record( - { - type: fc.constant('JsonRpcRequest'), - jsonrpc: fc.constant('2.0'), - method: fc.string(), - params: fc.object(), - id: fc.oneof(fc.string(), fc.integer(), fc.constant(null)), - }, - { - requiredKeys: ['type', 'jsonrpc', 'method', 'id'], - }, - ) - .noShrink() as fc.Arbitrary; - - const jsonRpcNotificationArb = fc - .record( - { - type: fc.constant('JsonRpcNotification'), - jsonrpc: fc.constant('2.0'), - method: fc.string(), - params: fc.object(), - }, - { - requiredKeys: ['type', 'jsonrpc', 'method'], - }, - ) - .noShrink() as fc.Arbitrary; - - const jsonRpcResponseResultArb = fc - .record({ - type: fc.constant('JsonRpcResponseResult'), - jsonrpc: fc.constant('2.0'), - result: fc.object(), - id: fc.oneof(fc.string(), fc.integer(), fc.constant(null)), - }) - .noShrink() as fc.Arbitrary; - - const jsonRpcErrorArb = fc - .record( - { - code: fc.integer(), - message: fc.string(), - data: fc.object(), - }, - { - requiredKeys: ['code', 'message'], - }, - ) - .noShrink() as fc.Arbitrary; - - const jsonRpcResponseErrorArb = fc - .record({ - type: fc.constant('JsonRpcResponseError'), - jsonrpc: fc.constant('2.0'), - error: jsonRpcErrorArb, - id: fc.oneof(fc.string(), fc.integer(), fc.constant(null)), - }) - .noShrink() as fc.Arbitrary; - - const jsonRpcMessageArb = fc - .oneof( - jsonRpcRequestArb, - jsonRpcNotificationArb, - jsonRpcResponseResultArb, - jsonRpcResponseErrorArb, - ) - .noShrink() as fc.Arbitrary; - - const snippingPatternArb = fc - .array(fc.integer({ min: 1, max: 32 }), { minLength: 100, size: 'medium' }) - .noShrink(); - - const jsonMessagesArb = fc - .array(jsonRpcRequestArb, { minLength: 2 }) - .noShrink(); - testProp( 'can parse json stream', - [jsonMessagesArb], + [rpcTestUtils.jsonMessagesArb], async (messages) => { - const parsedStream = jsonRpcStream(messages).pipeThrough( - new rpcUtils.JsonToJsonMessageStream(), - ); // Converting back. + const parsedStream = rpcTestUtils + .jsonRpcStream(messages) + .pipeThrough(new rpcUtils.JsonToJsonMessageStream()); // Converting back. const asd = await AsyncIterable.as(parsedStream).toArray(); expect(asd).toEqual(messages); @@ -125,10 +22,11 @@ describe('utils tests', () => { testProp( 'can parse json stream with random chunk sizes', - [jsonMessagesArb, snippingPatternArb], + [rpcTestUtils.jsonMessagesArb, rpcTestUtils.snippingPatternArb], async (messages, snippattern) => { - const parsedStream = jsonRpcStream(messages) - .pipeThrough(new BufferStreamToSnippedStream(snippattern)) // Imaginary internet here + const parsedStream = rpcTestUtils + .jsonRpcStream(messages) + .pipeThrough(new rpcTestUtils.BufferStreamToSnippedStream(snippattern)) // Imaginary internet here .pipeThrough(new rpcUtils.JsonToJsonMessageStream()); // Converting back. const asd = await AsyncIterable.as(parsedStream).toArray(); @@ -146,11 +44,12 @@ describe('utils tests', () => { testProp( 'Will error on bad data', - [jsonMessagesArb, snippingPatternArb, noiseArb], + [rpcTestUtils.jsonMessagesArb, rpcTestUtils.snippingPatternArb, noiseArb], async (messages, snippattern, noise) => { - const parsedStream = jsonRpcStream(messages) - .pipeThrough(new BufferStreamToSnippedStream(snippattern)) // Imaginary internet here - .pipeThrough(new BufferStreamToNoisyStream(noise)) // Adding bad data to the stream + const parsedStream = rpcTestUtils + .jsonRpcStream(messages) + .pipeThrough(new rpcTestUtils.BufferStreamToSnippedStream(snippattern)) // Imaginary internet here + .pipeThrough(new rpcTestUtils.BufferStreamToNoisyStream(noise)) // Adding bad data to the stream .pipeThrough(new rpcUtils.JsonToJsonMessageStream()); // Converting back. await expect(AsyncIterable.as(parsedStream).toArray()).rejects.toThrow( @@ -162,7 +61,7 @@ describe('utils tests', () => { testProp( 'can parse messages', - [jsonRpcMessageArb], + [rpcTestUtils.jsonRpcMessageArb], async (message) => { rpcUtils.parseJsonRpcMessage(message); }, diff --git a/tests/rpc/utils.ts b/tests/rpc/utils.ts index 32d3950d5..06e16ed3a 100644 --- a/tests/rpc/utils.ts +++ b/tests/rpc/utils.ts @@ -3,7 +3,19 @@ import type { TransformerFlushCallback, TransformerTransformCallback, } from 'stream/web'; -import { TransformStream } from 'stream/web'; +import type { POJO } from '@/types'; +import type { + JsonRpcError, + JsonRpcMessage, + JsonRpcNotification, + JsonRpcRequest, + JsonRpcResponseError, + JsonRpcResponseResult, +} from '@/rpc/types'; +import type { JsonValue } from 'fast-check'; +import { ReadableStream, WritableStream, TransformStream } from 'stream/web'; +import { fc } from '@fast-check/jest'; +import * as utils from '@/utils'; class BufferStreamToSnipped implements Transformer { protected buffer = Buffer.alloc(0); @@ -76,4 +88,154 @@ class BufferStreamToNoisyStream extends TransformStream { } } -export { BufferStreamToSnippedStream, BufferStreamToNoisyStream }; +const jsonRpcStream = (messages: Array) => { + return new ReadableStream({ + async start(controller) { + for (const arrayElement of messages) { + // Controller.enqueue(arrayElement) + controller.enqueue(Buffer.from(JSON.stringify(arrayElement), 'utf-8')); + } + controller.close(); + }, + }); +}; + +const jsonRpcRequestArb = ( + method: fc.Arbitrary = fc.string(), + params: fc.Arbitrary = fc.jsonValue(), +) => + fc + .record( + { + type: fc.constant('JsonRpcRequest'), + jsonrpc: fc.constant('2.0'), + method: method, + params: params, + id: fc.oneof(fc.string(), fc.integer(), fc.constant(null)), + }, + { + requiredKeys: ['type', 'jsonrpc', 'method', 'id'], + }, + ) + .noShrink() as fc.Arbitrary; + +const jsonRpcNotificationArb = fc + .record( + { + type: fc.constant('JsonRpcNotification'), + jsonrpc: fc.constant('2.0'), + method: fc.string(), + params: fc.jsonValue(), + }, + { + requiredKeys: ['type', 'jsonrpc', 'method'], + }, + ) + .noShrink() as fc.Arbitrary; + +const jsonRpcResponseResultArb = fc + .record({ + type: fc.constant('JsonRpcResponseResult'), + jsonrpc: fc.constant('2.0'), + result: fc.jsonValue(), + id: fc.oneof(fc.string(), fc.integer(), fc.constant(null)), + }) + .noShrink() as fc.Arbitrary; + +const jsonRpcErrorArb = fc + .record( + { + code: fc.integer(), + message: fc.string(), + data: fc.jsonValue(), + }, + { + requiredKeys: ['code', 'message'], + }, + ) + .noShrink() as fc.Arbitrary; + +const jsonRpcResponseErrorArb = fc + .record({ + type: fc.constant('JsonRpcResponseError'), + jsonrpc: fc.constant('2.0'), + error: jsonRpcErrorArb, + id: fc.oneof(fc.string(), fc.integer(), fc.constant(null)), + }) + .noShrink() as fc.Arbitrary; + +const jsonRpcMessageArb = fc + .oneof( + jsonRpcRequestArb(), + jsonRpcNotificationArb, + jsonRpcResponseResultArb, + jsonRpcResponseErrorArb, + ) + .noShrink() as fc.Arbitrary; + +const snippingPatternArb = fc + .array(fc.integer({ min: 1, max: 32 }), { minLength: 100, size: 'medium' }) + .noShrink(); + +const jsonMessagesArb = fc + .array(jsonRpcRequestArb(), { minLength: 2 }) + .noShrink(); + +function streamToArray(): [Promise>, WritableStream] { + const outputArray: Array = []; + const result = utils.promise>(); + const outputStream = new WritableStream({ + write: (chunk) => { + outputArray.push(chunk); + }, + close: () => { + result.resolveP(outputArray); + }, + abort: (reason) => { + result.rejectP(reason); + }, + }); + return [result.p, outputStream]; +} + +class Tap implements Transformer { + protected iteration = 0; + protected tapIterator; + + constructor(tapIterator: (chunk: T, iteration: number) => Promise) { + this.tapIterator = tapIterator; + } + + transform: TransformerTransformCallback = async (chunk, controller) => { + await this.tapIterator(chunk, this.iteration); + controller.enqueue(chunk); + this.iteration += 1; + }; +} + +/** + * This is used to convert regular chunks into randomly sized chunks based on + * a provided pattern. This is to replicate randomness introduced by packets + * splitting up the data. + */ +class TapStream extends TransformStream { + constructor(tapIterator: (chunk: T, iteration: number) => Promise) { + super(new Tap(tapIterator)); + } +} + +export { + BufferStreamToSnippedStream, + BufferStreamToNoisyStream, + jsonRpcStream, + jsonRpcRequestArb, + jsonRpcNotificationArb, + jsonRpcResponseResultArb, + jsonRpcErrorArb, + jsonRpcResponseErrorArb, + jsonRpcMessageArb, + snippingPatternArb, + jsonMessagesArb, + streamToArray, + TapStream, +}; From 6405186010d2b18643bff489de6c1794888f8d6e Mon Sep 17 00:00:00 2001 From: Brian Botha Date: Tue, 10 Jan 2023 19:44:45 +1100 Subject: [PATCH 04/44] fix: renaming Rpc to RPC --- src/{rpc/Rpc.ts => RPC/RPC.ts} | 17 +++++++---------- src/{rpc => RPC}/errors.ts | 0 src/{rpc => RPC}/index.ts | 0 src/{rpc => RPC}/types.ts | 0 src/{rpc => RPC}/utils.ts | 2 +- tests/{rpc => RPC}/Rpc.test.ts | 23 ++++++++++++----------- tests/{rpc => RPC}/utils.test.ts | 0 tests/{rpc => RPC}/utils.ts | 2 +- 8 files changed, 21 insertions(+), 23 deletions(-) rename src/{rpc/Rpc.ts => RPC/RPC.ts} (96%) rename src/{rpc => RPC}/errors.ts (100%) rename src/{rpc => RPC}/index.ts (100%) rename src/{rpc => RPC}/types.ts (100%) rename src/{rpc => RPC}/utils.ts (99%) rename tests/{rpc => RPC}/Rpc.test.ts (90%) rename tests/{rpc => RPC}/utils.test.ts (100%) rename tests/{rpc => RPC}/utils.ts (99%) diff --git a/src/rpc/Rpc.ts b/src/RPC/RPC.ts similarity index 96% rename from src/rpc/Rpc.ts rename to src/RPC/RPC.ts index 4eb7a22bd..6c0315ca9 100644 --- a/src/rpc/Rpc.ts +++ b/src/RPC/RPC.ts @@ -23,19 +23,19 @@ import * as rpcUtils from './utils'; import * as grpcUtils from '../grpc/utils'; // FIXME: Might need to be StartStop. Won't know for sure until it's used. -interface Rpc extends CreateDestroyStartStop {} +interface RPC extends CreateDestroyStartStop {} @CreateDestroyStartStop( new rpcErrors.ErrorRpcRunning(), new rpcErrors.ErrorRpcDestroyed(), ) -class Rpc { +class RPC { static async createRpc({ container, logger = new Logger(this.name), }: { container: POJO; logger?: Logger; - }): Promise { + }): Promise { logger.info(`Creating ${this.name}`); const rpc = new this({ container, @@ -103,11 +103,9 @@ class Rpc { connectionInfo, ctx, ) { - let count = 0; for await (const inputVal of input) { - if (count > 1) throw new rpcErrors.ErrorRpcProtocal(); yield handler(inputVal, container, connectionInfo, ctx); - count += 1; + break; } }; this.handlerMap.set(method, wrapperDuplex); @@ -124,11 +122,9 @@ class Rpc { connectionInfo, ctx, ) { - let count = 0; for await (const inputVal of input) { - if (count > 1) throw new rpcErrors.ErrorRpcProtocal(); yield* handler(inputVal, container, connectionInfo, ctx); - count += 1; + break; } }; this.handlerMap.set(method, wrapperDuplex); @@ -242,6 +238,7 @@ class Rpc { error: rpcError, id: null, }; + // TODO: catch this and emit error in the event emitter yield rpcErrorMessage; } resolve(); @@ -268,4 +265,4 @@ class Rpc { } } -export default Rpc; +export default RPC; diff --git a/src/rpc/errors.ts b/src/RPC/errors.ts similarity index 100% rename from src/rpc/errors.ts rename to src/RPC/errors.ts diff --git a/src/rpc/index.ts b/src/RPC/index.ts similarity index 100% rename from src/rpc/index.ts rename to src/RPC/index.ts diff --git a/src/rpc/types.ts b/src/RPC/types.ts similarity index 100% rename from src/rpc/types.ts rename to src/RPC/types.ts diff --git a/src/rpc/utils.ts b/src/RPC/utils.ts similarity index 99% rename from src/rpc/utils.ts rename to src/RPC/utils.ts index b2b380f47..4dc1299b7 100644 --- a/src/rpc/utils.ts +++ b/src/RPC/utils.ts @@ -6,7 +6,7 @@ import type { JsonRpcRequest, JsonRpcResponseError, JsonRpcResponseResult, -} from 'rpc/types'; +} from 'RPC/types'; import type { JSONValue } from '../types'; import { TransformStream } from 'stream/web'; import * as rpcErrors from './errors'; diff --git a/tests/rpc/Rpc.test.ts b/tests/RPC/Rpc.test.ts similarity index 90% rename from tests/rpc/Rpc.test.ts rename to tests/RPC/Rpc.test.ts index 80cdc823b..7f0a42eda 100644 --- a/tests/rpc/Rpc.test.ts +++ b/tests/RPC/Rpc.test.ts @@ -1,16 +1,16 @@ -import type { DuplexStreamHandler, JsonRpcMessage } from '@/rpc/types'; +import type { DuplexStreamHandler, JsonRpcMessage } from '@/RPC/types'; import type { JSONValue } from '@/types'; import type { ConnectionInfo, Host, Port } from '@/network/types'; import type { NodeId } from '@/ids'; import type { ReadableWritablePair } from 'stream/web'; import { testProp, fc } from '@fast-check/jest'; import Logger, { LogLevel, StreamHandler } from '@matrixai/logger'; -import Rpc from '@/rpc/Rpc'; -import * as rpcErrors from '@/rpc/errors'; +import RPC from '@/RPC/RPC'; +import * as rpcErrors from '@/RPC/errors'; import * as rpcTestUtils from './utils'; -describe(`${Rpc.name}`, () => { - const logger = new Logger(`${Rpc.name} Test`, LogLevel.WARN, [ +describe(`${RPC.name}`, () => { + const logger = new Logger(`${RPC.name} Test`, LogLevel.WARN, [ new StreamHandler(), ]); @@ -21,10 +21,10 @@ describe(`${Rpc.name}`, () => { }) .noShrink(); - testProp('can stream data', [specificMessageArb], async (messages) => { + testProp.only('can stream data', [specificMessageArb], async (messages) => { const stream = rpcTestUtils.jsonRpcStream(messages); const container = {}; - const rpc = await Rpc.createRpc({ container, logger }); + const rpc = await RPC.createRpc({ container, logger }); const [outputResult, outputStream] = rpcTestUtils.streamToArray(); const readWriteStream: ReadableWritablePair = { readable: stream, @@ -35,6 +35,7 @@ describe(`${Rpc.name}`, () => { async function* (input, _container, _connectionInfo, _ctx) { for await (const val of input) { yield val; + break; } }; @@ -53,7 +54,7 @@ describe(`${Rpc.name}`, () => { B: Symbol('b'), C: Symbol('c'), }; - const rpc = await Rpc.createRpc({ container, logger }); + const rpc = await RPC.createRpc({ container, logger }); const [outputResult, outputStream] = rpcTestUtils.streamToArray(); const readWriteStream: ReadableWritablePair = { readable: stream, @@ -88,7 +89,7 @@ describe(`${Rpc.name}`, () => { remotePort: 12341 as Port, }; const container = {}; - const rpc = await Rpc.createRpc({ container, logger }); + const rpc = await RPC.createRpc({ container, logger }); const [outputResult, outputStream] = rpcTestUtils.streamToArray(); const readWriteStream: ReadableWritablePair = { readable: stream, @@ -116,7 +117,7 @@ describe(`${Rpc.name}`, () => { async (messages) => { const stream = rpcTestUtils.jsonRpcStream(messages); const container = {}; - const rpc = await Rpc.createRpc({ container, logger }); + const rpc = await RPC.createRpc({ container, logger }); const [outputResult, outputStream] = rpcTestUtils.streamToArray(); let thing; let lastMessage: JsonRpcMessage | undefined; @@ -158,7 +159,7 @@ describe(`${Rpc.name}`, () => { testProp('Handler yields nothing', [specificMessageArb], async (messages) => { const stream = rpcTestUtils.jsonRpcStream(messages); const container = {}; - const rpc = await Rpc.createRpc({ container, logger }); + const rpc = await RPC.createRpc({ container, logger }); const [outputResult, outputStream] = rpcTestUtils.streamToArray(); const readWriteStream: ReadableWritablePair = { readable: stream, diff --git a/tests/rpc/utils.test.ts b/tests/RPC/utils.test.ts similarity index 100% rename from tests/rpc/utils.test.ts rename to tests/RPC/utils.test.ts diff --git a/tests/rpc/utils.ts b/tests/RPC/utils.ts similarity index 99% rename from tests/rpc/utils.ts rename to tests/RPC/utils.ts index 06e16ed3a..d9b91be0b 100644 --- a/tests/rpc/utils.ts +++ b/tests/RPC/utils.ts @@ -11,7 +11,7 @@ import type { JsonRpcRequest, JsonRpcResponseError, JsonRpcResponseResult, -} from '@/rpc/types'; +} from '@/RPC/types'; import type { JsonValue } from 'fast-check'; import { ReadableStream, WritableStream, TransformStream } from 'stream/web'; import { fc } from '@fast-check/jest'; From 1764e3814a1f079c58b0a1053bccb6e7564207d2 Mon Sep 17 00:00:00 2001 From: Brian Botha Date: Wed, 11 Jan 2023 15:00:20 +1100 Subject: [PATCH 05/44] fix: fixing up stream parsing --- src/RPC/utils.ts | 72 +++++++++++------------------------------ tests/RPC/Rpc.test.ts | 2 +- tests/RPC/utils.test.ts | 7 ++-- 3 files changed, 24 insertions(+), 57 deletions(-) diff --git a/src/RPC/utils.ts b/src/RPC/utils.ts index 4dc1299b7..378e73dee 100644 --- a/src/RPC/utils.ts +++ b/src/RPC/utils.ts @@ -1,4 +1,8 @@ -import type { Transformer, TransformerTransformCallback } from 'stream/web'; +import type { + Transformer, + TransformerTransformCallback, + TransformerStartCallback, +} from 'stream/web'; import type { JsonRpcError, JsonRpcMessage, @@ -10,71 +14,31 @@ import type { import type { JSONValue } from '../types'; import { TransformStream } from 'stream/web'; import * as rpcErrors from './errors'; +import * as rpcUtils from './utils'; import * as utils from '../utils'; import * as validationErrors from '../validation/errors'; -import { promise } from '../utils'; const jsonStreamParsers = require('@streamparser/json'); class JsonToJsonMessage implements Transformer { - protected buffer = Buffer.alloc(0); + protected parser = new jsonStreamParsers.JSONParser({ separator: '' }); - /** - * This function finds the index of the closing `}` bracket of the top level - * JSON object. It makes use of a JSON parser tokenizer to find the `{}` - * tokens that are not within strings and counts them to find the top level - * matching `{}` pair. - */ - protected async findCompleteMessageIndex(input: Buffer): Promise { - const tokenizer = new jsonStreamParsers.Tokenizer(); - let braceCount = 0; - let escapes = 0; - const foundOffset = promise(); - tokenizer.onToken = (tokenData) => { - if (tokenData.token === jsonStreamParsers.TokenType.LEFT_BRACE) { - braceCount += 1; - } else if (tokenData.token === jsonStreamParsers.TokenType.RIGHT_BRACE) { - braceCount += -1; - if (braceCount === 0) foundOffset.resolveP(tokenData.offset + escapes); - } else if (tokenData.token === jsonStreamParsers.TokenType.STRING) { - const string = tokenData.value as string; - // `JSON.stringify` changes the length of a string when special - // characters are present. This makes the offset we find wrong when - // getting the substring. We need to compensate for this by getting the - // difference in string length. - escapes += JSON.stringify([string]).length - string.length - 4; + start: TransformerStartCallback = async (controller) => { + this.parser.onValue = (value) => { + if (value.parent === undefined) { + const jsonMessage = rpcUtils.parseJsonRpcMessage(value.value); + controller.enqueue(jsonMessage); } }; - tokenizer.onEnd = () => foundOffset.resolveP(-1); - try { - tokenizer.write(input); - } catch (e) { - throw new rpcErrors.ErrorRpcParse('TMP StreamParseError', { cause: e }); - } - try { - tokenizer.end(); - } catch { - foundOffset.resolveP(-1); - } - return await foundOffset.p; - } + }; transform: TransformerTransformCallback = async ( chunk, - controller, + _controller, ) => { - this.buffer = Buffer.concat([this.buffer, chunk]); - while (this.buffer.length > 0) { - const index = await this.findCompleteMessageIndex(this.buffer); - if (index <= 0) break; - const outputBuffer = this.buffer.subarray(0, index + 1); - try { - const jsonObject = JSON.parse(outputBuffer.toString('utf-8')); - const jsonRpcMessage = parseJsonRpcMessage(jsonObject); - controller.enqueue(jsonRpcMessage); - } catch (e) { - throw new rpcErrors.ErrorRpcParse(undefined, { cause: e }); - } - this.buffer = this.buffer.subarray(index + 1); + try { + this.parser.write(chunk); + } catch (e) { + throw new rpcErrors.ErrorRpcParse(undefined, { cause: e }); } }; } diff --git a/tests/RPC/Rpc.test.ts b/tests/RPC/Rpc.test.ts index 7f0a42eda..63c0810da 100644 --- a/tests/RPC/Rpc.test.ts +++ b/tests/RPC/Rpc.test.ts @@ -21,7 +21,7 @@ describe(`${RPC.name}`, () => { }) .noShrink(); - testProp.only('can stream data', [specificMessageArb], async (messages) => { + testProp('can stream data', [specificMessageArb], async (messages) => { const stream = rpcTestUtils.jsonRpcStream(messages); const container = {}; const rpc = await RPC.createRpc({ container, logger }); diff --git a/tests/RPC/utils.test.ts b/tests/RPC/utils.test.ts index 75276b49c..169b82ab6 100644 --- a/tests/RPC/utils.test.ts +++ b/tests/RPC/utils.test.ts @@ -1,8 +1,8 @@ import { testProp, fc } from '@fast-check/jest'; import { AsyncIterableX as AsyncIterable } from 'ix/asynciterable'; -import * as rpcUtils from '@/rpc/utils'; +import * as rpcUtils from '@/RPC/utils'; import 'ix/add/asynciterable-operators/toarray'; -import * as rpcErrors from '@/rpc/errors'; +import * as rpcErrors from '@/RPC/errors'; import * as rpcTestUtils from './utils'; describe('utils tests', () => { @@ -67,4 +67,7 @@ describe('utils tests', () => { }, { numRuns: 1000 }, ); + + // TODO: + // - Test for badly structured data }); From eb60e1fae320ba8465952dbd26425b629760207c Mon Sep 17 00:00:00 2001 From: Brian Botha Date: Wed, 11 Jan 2023 15:10:34 +1100 Subject: [PATCH 06/44] fix: swapped client and server streaming logic, it was reversed --- src/RPC/RPC.ts | 12 ++++++------ src/RPC/types.ts | 6 +++--- 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/src/RPC/RPC.ts b/src/RPC/RPC.ts index 6c0315ca9..692c1a0b6 100644 --- a/src/RPC/RPC.ts +++ b/src/RPC/RPC.ts @@ -1,11 +1,11 @@ import type { - ClientStreamHandler, + ServerStreamHandler, DuplexStreamHandler, JsonRpcError, JsonRpcMessage, JsonRpcResponseError, JsonRpcResponseResult, - ServerStreamHandler, + ClientStreamHandler, } from './types'; import type { ReadableWritablePair } from 'stream/web'; import type { JSONValue, POJO } from '../types'; @@ -112,9 +112,9 @@ class RPC { } @ready(new rpcErrors.ErrorRpcNotRunning()) - public registerClientStreamHandler( + public registerServerStreamHandler( method: string, - handler: ClientStreamHandler, + handler: ServerStreamHandler, ) { const wrapperDuplex: DuplexStreamHandler = async function* ( input, @@ -131,9 +131,9 @@ class RPC { } @ready(new rpcErrors.ErrorRpcNotRunning()) - public registerServerStreamHandler( + public registerClientStreamHandler( method: string, - handler: ServerStreamHandler, + handler: ClientStreamHandler, ) { const wrapperDuplex: DuplexStreamHandler = async function* ( input, diff --git a/src/RPC/types.ts b/src/RPC/types.ts index 71f5ebb21..9aebab06c 100644 --- a/src/RPC/types.ts +++ b/src/RPC/types.ts @@ -112,11 +112,11 @@ type DuplexStreamHandler = Handler< AsyncGenerator, AsyncGenerator >; -type ClientStreamHandler = Handler< +type ServerStreamHandler = Handler< I, AsyncGenerator >; -type ServerStreamHandler = Handler< +type ClientStreamHandler = Handler< AsyncGenerator, Promise >; @@ -134,7 +134,7 @@ export type { JsonRpcResponse, JsonRpcMessage, DuplexStreamHandler, - ClientStreamHandler, ServerStreamHandler, + ClientStreamHandler, UnaryHandler, }; From abb6202a2739e5f584a735d7aefee81780e7fd2a Mon Sep 17 00:00:00 2001 From: Brian Botha Date: Wed, 11 Jan 2023 17:54:58 +1100 Subject: [PATCH 07/44] tests: creating tests for handlers --- tests/RPC/Rpc.test.ts | 96 +++++++++++++++++++++++++++++++++++++++---- 1 file changed, 88 insertions(+), 8 deletions(-) diff --git a/tests/RPC/Rpc.test.ts b/tests/RPC/Rpc.test.ts index 63c0810da..9b49c2104 100644 --- a/tests/RPC/Rpc.test.ts +++ b/tests/RPC/Rpc.test.ts @@ -1,4 +1,10 @@ -import type { DuplexStreamHandler, JsonRpcMessage } from '@/RPC/types'; +import type { + ClientStreamHandler, + DuplexStreamHandler, + JsonRpcMessage, + ServerStreamHandler, + UnaryHandler +} from "@/RPC/types"; import type { JSONValue } from '@/types'; import type { ConnectionInfo, Host, Port } from '@/network/types'; import type { NodeId } from '@/ids'; @@ -21,7 +27,7 @@ describe(`${RPC.name}`, () => { }) .noShrink(); - testProp('can stream data', [specificMessageArb], async (messages) => { + testProp('can stream data with duplex stream handler', [specificMessageArb], async (messages) => { const stream = rpcTestUtils.jsonRpcStream(messages); const container = {}; const rpc = await RPC.createRpc({ container, logger }); @@ -44,6 +50,83 @@ describe(`${RPC.name}`, () => { await outputResult; }); + testProp('can stream data with client stream handler', [specificMessageArb], async (messages) => { + const stream = rpcTestUtils.jsonRpcStream(messages); + const container = {}; + const rpc = await RPC.createRpc({ container, logger }); + const [outputResult, outputStream] = rpcTestUtils.streamToArray(); + const readWriteStream: ReadableWritablePair = { + readable: stream, + writable: outputStream, + }; + + const clientHandler: ClientStreamHandler = + async function (input, _container, _connectionInfo, _ctx) { + let count = 0; + for await (const val of input) { + count += 1; + } + return count; + }; + + rpc.registerClientStreamHandler(methodName, clientHandler); + rpc.handleStream(readWriteStream, {} as ConnectionInfo); + await outputResult + }); + + const singleNumberMessageArb = fc.array( + rpcTestUtils.jsonRpcRequestArb( + fc.constant(methodName), + fc.integer({min: 1, max: 20}), + ), + { + minLength: 2, + maxLength: 10, + } + ) + + testProp('can stream data with server stream handler', [singleNumberMessageArb], async (messages) => { + const stream = rpcTestUtils.jsonRpcStream(messages); + const container = {}; + const rpc = await RPC.createRpc({ container, logger }); + const [outputResult, outputStream] = rpcTestUtils.streamToArray(); + const readWriteStream: ReadableWritablePair = { + readable: stream, + writable: outputStream, + }; + + const serverHandler: ServerStreamHandler = + async function* (input, _container, _connectionInfo, _ctx) { + for (let i = 0; i < input; i++) { + yield i; + } + }; + + rpc.registerServerStreamHandler(methodName, serverHandler); + rpc.handleStream(readWriteStream, {} as ConnectionInfo); + await outputResult + }); + + testProp('can stream data with server stream handler', [specificMessageArb], async (messages) => { + const stream = rpcTestUtils.jsonRpcStream(messages); + const container = {}; + const rpc = await RPC.createRpc({ container, logger }); + const [outputResult, outputStream] = rpcTestUtils.streamToArray(); + const readWriteStream: ReadableWritablePair = { + readable: stream, + writable: outputStream, + }; + + const unaryHandler: UnaryHandler = + async function (input, _container, _connectionInfo, _ctx) { + return input; + }; + + rpc.registerUnaryHandler(methodName, unaryHandler); + rpc.handleStream(readWriteStream, {} as ConnectionInfo); + await outputResult; + }); + testProp( 'Handler is provided with container', [specificMessageArb], @@ -177,13 +260,10 @@ describe(`${RPC.name}`, () => { rpc.handleStream(readWriteStream, {} as ConnectionInfo); await outputResult; // We're just expecting no errors - }); + } + ); // TODO: - // - Test for each type of handler - // - duplex - // - client stream - // - server stream - // - unary // - Test odd conditions for handlers, like extra messages where 1 is expected. + // - Expectations can't be inside the handlers otherwise they're caught. }); From 74f10b8fd1b6accd36b10dc9cb404b7738f98b7f Mon Sep 17 00:00:00 2001 From: Brian Botha Date: Thu, 12 Jan 2023 19:05:00 +1100 Subject: [PATCH 08/44] feat: creating `RPCClient` --- src/RPC/RPCClient.ts | 115 ++++++++++++++++++ src/RPC/{RPC.ts => RPCServer.ts} | 16 +-- tests/RPC/Rpc.test.ts | 198 +++++++++++++++++-------------- tests/RPC/utils.ts | 4 +- 4 files changed, 235 insertions(+), 98 deletions(-) create mode 100644 src/RPC/RPCClient.ts rename src/RPC/{RPC.ts => RPCServer.ts} (97%) diff --git a/src/RPC/RPCClient.ts b/src/RPC/RPCClient.ts new file mode 100644 index 000000000..4ec774eab --- /dev/null +++ b/src/RPC/RPCClient.ts @@ -0,0 +1,115 @@ +import type { PromiseCancellable } from '@matrixai/async-cancellable'; +import type { JSONValue, POJO } from 'types'; +import type { + ReadableWritablePair, + ReadableStream, + WritableStream, +} from 'stream/web'; +import { StartStop } from '@matrixai/async-init/dist/StartStop'; +import * as rpcErrors from 'RPC/errors'; +import Logger from '@matrixai/logger'; + +type QuicConnection = { + establishStream: (stream: ReadableWritablePair) => Promise; +}; + +interface RPCServer extends StartStop {} +@StartStop() +class RPCServer { + static async createRPCClient({ + quicConnection, + logger = new Logger(this.name), + }: { + quicConnection: QuicConnection; + logger: Logger; + }) { + logger.info(`Creating ${this.name}`); + const rpcClient = new this({ + quicConnection, + logger, + }); + await rpcClient.start(); + logger.info(`Created ${this.name}`); + return rpcClient; + } + + protected logger: Logger; + protected activeStreams: Set> = new Set(); + protected quicConnection: QuicConnection; + + public constructor({ + quicConnection, + logger, + }: { + quicConnection: QuicConnection; + logger: Logger; + }) { + this.logger = logger; + this.quicConnection = quicConnection; + } + + public async start(): Promise { + this.logger.info(`Starting ${this.constructor.name}`); + this.logger.info(`Started ${this.constructor.name}`); + } + + public async stop(): Promise { + this.logger.info(`Stopping ${this.constructor.name}`); + for await (const [stream] of this.activeStreams.entries()) { + stream.cancel(new rpcErrors.ErrorRpcStopping()); + } + for await (const [stream] of this.activeStreams.entries()) { + await stream; + } + this.logger.info(`Stopped ${this.constructor.name}`); + } + + protected duplexCaller( + method: string, + metadata: POJO, + ): AsyncGenerator { + // The stream pair is the interface with the quic system. The readable is + // considered the output while the writeable is the input to the caller. + const pair: ReadableWritablePair = { + readable: {} as ReadableStream, + writable: {} as WritableStream, + }; + + const inputGen = async function* (): AsyncGenerator { + const writer = pair.writable.getWriter(); + let value: I | null; + try { + while (true) { + value = yield; + if (value === null) break; + await writer.write(value); + } + await writer.close(); + } catch (e) { + await writer.abort(e); + } + }; + + const outputGen = async function* (): AsyncGenerator { + const reader = pair.readable.getReader(); + while (true) { + const { value, done } = await reader.read(); + if (done) break; + yield value; + } + }; + const output = outputGen(); + const input = inputGen(); + + const inter = { + read: () => output.next(), + write: (value: I | null) => input.next(value), + }; + + const duplexGenerator = async function* (): AsyncGenerator { + const otherThing: O = {} as O; + const thing = yield otherThing; + }; + return duplexGenerator(); + } +} diff --git a/src/RPC/RPC.ts b/src/RPC/RPCServer.ts similarity index 97% rename from src/RPC/RPC.ts rename to src/RPC/RPCServer.ts index 692c1a0b6..a520294f9 100644 --- a/src/RPC/RPC.ts +++ b/src/RPC/RPCServer.ts @@ -23,27 +23,27 @@ import * as rpcUtils from './utils'; import * as grpcUtils from '../grpc/utils'; // FIXME: Might need to be StartStop. Won't know for sure until it's used. -interface RPC extends CreateDestroyStartStop {} +interface RPCServer extends CreateDestroyStartStop {} @CreateDestroyStartStop( new rpcErrors.ErrorRpcRunning(), new rpcErrors.ErrorRpcDestroyed(), ) -class RPC { - static async createRpc({ +class RPCServer { + static async createRPCServer({ container, logger = new Logger(this.name), }: { container: POJO; logger?: Logger; - }): Promise { + }): Promise { logger.info(`Creating ${this.name}`); - const rpc = new this({ + const rpcServer = new this({ container, logger, }); - await rpc.start(); + await rpcServer.start(); logger.info(`Created ${this.name}`); - return rpc; + return rpcServer; } // Properties @@ -265,4 +265,4 @@ class RPC { } } -export default RPC; +export default RPCServer; diff --git a/tests/RPC/Rpc.test.ts b/tests/RPC/Rpc.test.ts index 9b49c2104..528c0565b 100644 --- a/tests/RPC/Rpc.test.ts +++ b/tests/RPC/Rpc.test.ts @@ -3,20 +3,20 @@ import type { DuplexStreamHandler, JsonRpcMessage, ServerStreamHandler, - UnaryHandler -} from "@/RPC/types"; + UnaryHandler, +} from '@/RPC/types'; import type { JSONValue } from '@/types'; import type { ConnectionInfo, Host, Port } from '@/network/types'; import type { NodeId } from '@/ids'; import type { ReadableWritablePair } from 'stream/web'; import { testProp, fc } from '@fast-check/jest'; import Logger, { LogLevel, StreamHandler } from '@matrixai/logger'; -import RPC from '@/RPC/RPC'; +import RPCServer from '@/RPC/RPCServer'; import * as rpcErrors from '@/RPC/errors'; import * as rpcTestUtils from './utils'; -describe(`${RPC.name}`, () => { - const logger = new Logger(`${RPC.name} Test`, LogLevel.WARN, [ +describe(`${RPCServer.name}`, () => { + const logger = new Logger(`${RPCServer.name} Test`, LogLevel.WARN, [ new StreamHandler(), ]); @@ -27,105 +27,125 @@ describe(`${RPC.name}`, () => { }) .noShrink(); - testProp('can stream data with duplex stream handler', [specificMessageArb], async (messages) => { - const stream = rpcTestUtils.jsonRpcStream(messages); - const container = {}; - const rpc = await RPC.createRpc({ container, logger }); - const [outputResult, outputStream] = rpcTestUtils.streamToArray(); - const readWriteStream: ReadableWritablePair = { - readable: stream, - writable: outputStream, - }; - - const duplexHandler: DuplexStreamHandler = - async function* (input, _container, _connectionInfo, _ctx) { - for await (const val of input) { - yield val; - break; - } + testProp( + 'can stream data with duplex stream handler', + [specificMessageArb], + async (messages) => { + const stream = rpcTestUtils.jsonRpcStream(messages); + const container = {}; + const rpc = await RPCServer.createRPCServer({ container, logger }); + const [outputResult, outputStream] = rpcTestUtils.streamToArray(); + const readWriteStream: ReadableWritablePair = { + readable: stream, + writable: outputStream, }; - rpc.registerDuplexStreamHandler(methodName, duplexHandler); - rpc.handleStream(readWriteStream, {} as ConnectionInfo); - await outputResult; - }); + const duplexHandler: DuplexStreamHandler = + async function* (input, _container, _connectionInfo, _ctx) { + for await (const val of input) { + yield val; + break; + } + }; - testProp('can stream data with client stream handler', [specificMessageArb], async (messages) => { - const stream = rpcTestUtils.jsonRpcStream(messages); - const container = {}; - const rpc = await RPC.createRpc({ container, logger }); - const [outputResult, outputStream] = rpcTestUtils.streamToArray(); - const readWriteStream: ReadableWritablePair = { - readable: stream, - writable: outputStream, - }; + rpc.registerDuplexStreamHandler(methodName, duplexHandler); + rpc.handleStream(readWriteStream, {} as ConnectionInfo); + await outputResult; + }, + ); - const clientHandler: ClientStreamHandler = - async function (input, _container, _connectionInfo, _ctx) { - let count = 0; - for await (const val of input) { - count += 1; - } - return count; + testProp( + 'can stream data with client stream handler', + [specificMessageArb], + async (messages) => { + const stream = rpcTestUtils.jsonRpcStream(messages); + const container = {}; + const rpc = await RPCServer.createRPCServer({ container, logger }); + const [outputResult, outputStream] = rpcTestUtils.streamToArray(); + const readWriteStream: ReadableWritablePair = { + readable: stream, + writable: outputStream, }; - rpc.registerClientStreamHandler(methodName, clientHandler); - rpc.handleStream(readWriteStream, {} as ConnectionInfo); - await outputResult - }); + const clientHandler: ClientStreamHandler = + async function (input, _container, _connectionInfo, _ctx) { + let count = 0; + for await (const _ of input) { + count += 1; + } + return count; + }; + + rpc.registerClientStreamHandler(methodName, clientHandler); + rpc.handleStream(readWriteStream, {} as ConnectionInfo); + await outputResult; + }, + ); const singleNumberMessageArb = fc.array( rpcTestUtils.jsonRpcRequestArb( fc.constant(methodName), - fc.integer({min: 1, max: 20}), + fc.integer({ min: 1, max: 20 }), ), { minLength: 2, maxLength: 10, - } - ) - - testProp('can stream data with server stream handler', [singleNumberMessageArb], async (messages) => { - const stream = rpcTestUtils.jsonRpcStream(messages); - const container = {}; - const rpc = await RPC.createRpc({ container, logger }); - const [outputResult, outputStream] = rpcTestUtils.streamToArray(); - const readWriteStream: ReadableWritablePair = { - readable: stream, - writable: outputStream, - }; + }, + ); - const serverHandler: ServerStreamHandler = - async function* (input, _container, _connectionInfo, _ctx) { - for (let i = 0; i < input; i++) { - yield i; - } + testProp( + 'can stream data with server stream handler', + [singleNumberMessageArb], + async (messages) => { + const stream = rpcTestUtils.jsonRpcStream(messages); + const container = {}; + const rpc = await RPCServer.createRPCServer({ container, logger }); + const [outputResult, outputStream] = rpcTestUtils.streamToArray(); + const readWriteStream: ReadableWritablePair = { + readable: stream, + writable: outputStream, }; - rpc.registerServerStreamHandler(methodName, serverHandler); - rpc.handleStream(readWriteStream, {} as ConnectionInfo); - await outputResult - }); + const serverHandler: ServerStreamHandler = + async function* (input, _container, _connectionInfo, _ctx) { + for (let i = 0; i < input; i++) { + yield i; + } + }; - testProp('can stream data with server stream handler', [specificMessageArb], async (messages) => { - const stream = rpcTestUtils.jsonRpcStream(messages); - const container = {}; - const rpc = await RPC.createRpc({ container, logger }); - const [outputResult, outputStream] = rpcTestUtils.streamToArray(); - const readWriteStream: ReadableWritablePair = { - readable: stream, - writable: outputStream, - }; + rpc.registerServerStreamHandler(methodName, serverHandler); + rpc.handleStream(readWriteStream, {} as ConnectionInfo); + await outputResult; + }, + ); + + testProp( + 'can stream data with server stream handler', + [specificMessageArb], + async (messages) => { + const stream = rpcTestUtils.jsonRpcStream(messages); + const container = {}; + const rpc = await RPCServer.createRPCServer({ container, logger }); + const [outputResult, outputStream] = rpcTestUtils.streamToArray(); + const readWriteStream: ReadableWritablePair = { + readable: stream, + writable: outputStream, + }; - const unaryHandler: UnaryHandler = - async function (input, _container, _connectionInfo, _ctx) { - return input; + const unaryHandler: UnaryHandler = async function ( + input, + _container, + _connectionInfo, + _ctx, + ) { + return input; }; - rpc.registerUnaryHandler(methodName, unaryHandler); - rpc.handleStream(readWriteStream, {} as ConnectionInfo); - await outputResult; - }); + rpc.registerUnaryHandler(methodName, unaryHandler); + rpc.handleStream(readWriteStream, {} as ConnectionInfo); + await outputResult; + }, + ); testProp( 'Handler is provided with container', @@ -137,7 +157,7 @@ describe(`${RPC.name}`, () => { B: Symbol('b'), C: Symbol('c'), }; - const rpc = await RPC.createRpc({ container, logger }); + const rpc = await RPCServer.createRPCServer({ container, logger }); const [outputResult, outputStream] = rpcTestUtils.streamToArray(); const readWriteStream: ReadableWritablePair = { readable: stream, @@ -172,7 +192,7 @@ describe(`${RPC.name}`, () => { remotePort: 12341 as Port, }; const container = {}; - const rpc = await RPC.createRpc({ container, logger }); + const rpc = await RPCServer.createRPCServer({ container, logger }); const [outputResult, outputStream] = rpcTestUtils.streamToArray(); const readWriteStream: ReadableWritablePair = { readable: stream, @@ -200,7 +220,7 @@ describe(`${RPC.name}`, () => { async (messages) => { const stream = rpcTestUtils.jsonRpcStream(messages); const container = {}; - const rpc = await RPC.createRpc({ container, logger }); + const rpc = await RPCServer.createRPCServer({ container, logger }); const [outputResult, outputStream] = rpcTestUtils.streamToArray(); let thing; let lastMessage: JsonRpcMessage | undefined; @@ -242,7 +262,7 @@ describe(`${RPC.name}`, () => { testProp('Handler yields nothing', [specificMessageArb], async (messages) => { const stream = rpcTestUtils.jsonRpcStream(messages); const container = {}; - const rpc = await RPC.createRpc({ container, logger }); + const rpc = await RPCServer.createRPCServer({ container, logger }); const [outputResult, outputStream] = rpcTestUtils.streamToArray(); const readWriteStream: ReadableWritablePair = { readable: stream, @@ -260,10 +280,10 @@ describe(`${RPC.name}`, () => { rpc.handleStream(readWriteStream, {} as ConnectionInfo); await outputResult; // We're just expecting no errors - } - ); + }); // TODO: // - Test odd conditions for handlers, like extra messages where 1 is expected. // - Expectations can't be inside the handlers otherwise they're caught. + // - get the tap transform stream working }); diff --git a/tests/RPC/utils.ts b/tests/RPC/utils.ts index d9b91be0b..81ce3913e 100644 --- a/tests/RPC/utils.ts +++ b/tests/RPC/utils.ts @@ -102,7 +102,9 @@ const jsonRpcStream = (messages: Array) => { const jsonRpcRequestArb = ( method: fc.Arbitrary = fc.string(), - params: fc.Arbitrary = fc.jsonValue(), + params: fc.Arbitrary = fc + .jsonValue() + .map((value) => JSON.parse(JSON.stringify(value))), ) => fc .record( From d2a1aefd915cdc76e5bc9f01022c65e3bfdc9e79 Mon Sep 17 00:00:00 2001 From: Brian Botha Date: Mon, 16 Jan 2023 12:47:44 +1100 Subject: [PATCH 09/44] fix: updating parser path filtering --- src/RPC/utils.ts | 8 +++----- tests/RPC/{Rpc.test.ts => RPCServer.test.ts} | 0 2 files changed, 3 insertions(+), 5 deletions(-) rename tests/RPC/{Rpc.test.ts => RPCServer.test.ts} (100%) diff --git a/src/RPC/utils.ts b/src/RPC/utils.ts index 378e73dee..fe4bd43dd 100644 --- a/src/RPC/utils.ts +++ b/src/RPC/utils.ts @@ -20,14 +20,12 @@ import * as validationErrors from '../validation/errors'; const jsonStreamParsers = require('@streamparser/json'); class JsonToJsonMessage implements Transformer { - protected parser = new jsonStreamParsers.JSONParser({ separator: '' }); + protected parser = new jsonStreamParsers.JSONParser({ separator: '', paths: ['$'] }); start: TransformerStartCallback = async (controller) => { this.parser.onValue = (value) => { - if (value.parent === undefined) { - const jsonMessage = rpcUtils.parseJsonRpcMessage(value.value); - controller.enqueue(jsonMessage); - } + const jsonMessage = rpcUtils.parseJsonRpcMessage(value.value); + controller.enqueue(jsonMessage); }; }; diff --git a/tests/RPC/Rpc.test.ts b/tests/RPC/RPCServer.test.ts similarity index 100% rename from tests/RPC/Rpc.test.ts rename to tests/RPC/RPCServer.test.ts From 28f7ebae860a92d40b982fb8ee91ab952395392b Mon Sep 17 00:00:00 2001 From: Brian Botha Date: Mon, 16 Jan 2023 14:17:12 +1100 Subject: [PATCH 10/44] feat: generic client callers Related #501 [ci skip] --- src/RPC/RPCClient.ts | 161 ++++++++++++++++++++++++++++------- src/RPC/RPCServer.ts | 4 + src/RPC/types.ts | 37 ++++++++ src/RPC/utils.ts | 5 +- tests/RPC/RPCClient.test.ts | 165 ++++++++++++++++++++++++++++++++++++ 5 files changed, 340 insertions(+), 32 deletions(-) create mode 100644 tests/RPC/RPCClient.test.ts diff --git a/src/RPC/RPCClient.ts b/src/RPC/RPCClient.ts index 4ec774eab..6e293d3c8 100644 --- a/src/RPC/RPCClient.ts +++ b/src/RPC/RPCClient.ts @@ -1,21 +1,24 @@ +import type { + DuplexCallerInterface, + ServerCallerInterface, + ClientCallerInterface, + JsonRpcRequest, +} from './types'; import type { PromiseCancellable } from '@matrixai/async-cancellable'; import type { JSONValue, POJO } from 'types'; -import type { - ReadableWritablePair, - ReadableStream, - WritableStream, -} from 'stream/web'; +import type { ReadableWritablePair } from 'stream/web'; import { StartStop } from '@matrixai/async-init/dist/StartStop'; -import * as rpcErrors from 'RPC/errors'; import Logger from '@matrixai/logger'; +import * as rpcErrors from './errors'; +import * as rpcUtils from './utils'; type QuicConnection = { - establishStream: (stream: ReadableWritablePair) => Promise; + // EstablishStream: (stream: ReadableWritablePair) => Promise; }; -interface RPCServer extends StartStop {} +interface RPCClient extends StartStop {} @StartStop() -class RPCServer { +class RPCClient { static async createRPCClient({ quicConnection, logger = new Logger(this.name), @@ -64,52 +67,148 @@ class RPCServer { this.logger.info(`Stopped ${this.constructor.name}`); } - protected duplexCaller( + public async duplexStreamCaller( method: string, metadata: POJO, - ): AsyncGenerator { - // The stream pair is the interface with the quic system. The readable is - // considered the output while the writeable is the input to the caller. - const pair: ReadableWritablePair = { - readable: {} as ReadableStream, - writable: {} as WritableStream, - }; + streamPair: ReadableWritablePair, + ): Promise> { + const inputStream = streamPair.readable.pipeThrough( + new rpcUtils.JsonToJsonMessageStream(), + ); + const outputTransform = new rpcUtils.JsonMessageToJsonStream(); + void outputTransform.readable.pipeTo(streamPair.writable); - const inputGen = async function* (): AsyncGenerator { - const writer = pair.writable.getWriter(); - let value: I | null; + const inputGen = async function* (): AsyncGenerator { + const writer = outputTransform.writable.getWriter(); + let value: I; try { while (true) { value = yield; - if (value === null) break; - await writer.write(value); + const message: JsonRpcRequest = { + method, + type: 'JsonRpcRequest', + jsonrpc: '2.0', + id: null, + params: value, + }; + await writer.write(message); } - await writer.close(); } catch (e) { await writer.abort(e); + } finally { + await writer.close(); } }; const outputGen = async function* (): AsyncGenerator { - const reader = pair.readable.getReader(); + const reader = inputStream.getReader(); while (true) { const { value, done } = await reader.read(); if (done) break; - yield value; + if ( + value?.type === 'JsonRpcRequest' || + value?.type === 'JsonRpcNotification' + ) { + yield value.params as O; + } } }; const output = outputGen(); const input = inputGen(); + // Initiating the input generator + await input.next(); - const inter = { + const inter: DuplexCallerInterface = { read: () => output.next(), - write: (value: I | null) => input.next(value), + write: async (value: I) => { + await input.next(value); + }, + inputGenerator: input, + outputGenerator: output, + end: async () => { + await input.return(); + }, + close: async () => { + await output.return(); + }, + throw: async (reason: any) => { + await input.throw(reason); + await output.throw(reason); + }, + }; + return inter; + } + + public async serverStreamCaller( + method: string, + parameters: I, + metadata: POJO, + streamPair: ReadableWritablePair, + ): Promise> { + const callerInterface = await this.duplexStreamCaller( + method, + metadata, + streamPair, + ); + await callerInterface.write(parameters); + await callerInterface.end(); + + return { + read: () => callerInterface.read(), + outputGenerator: callerInterface.outputGenerator, + close: () => callerInterface.close(), + throw: async (reason: any) => { + await callerInterface.outputGenerator.throw(reason); + }, }; + } - const duplexGenerator = async function* (): AsyncGenerator { - const otherThing: O = {} as O; - const thing = yield otherThing; + public async clientStreamCaller( + method: string, + metadata: POJO, + streamPair: ReadableWritablePair, + ): Promise> { + const callerInterface = await this.duplexStreamCaller( + method, + metadata, + streamPair, + ); + const output = callerInterface + .read() + .then(({ value, done }) => { + if (done) throw Error('TMP Stream closed early'); + return value; + }) + .finally(async () => { + await callerInterface.close(); + }); + return { + write: (value: I) => callerInterface.write(value), + result: output, + inputGenerator: callerInterface.inputGenerator, + end: () => callerInterface.end(), + throw: (reason: any) => callerInterface.throw(reason), }; - return duplexGenerator(); + } + + public async unaryCaller( + method: string, + parameters: I, + metadata: POJO, + streamPair: ReadableWritablePair, + ): Promise { + const callerInterface = await this.duplexStreamCaller( + method, + metadata, + streamPair, + ); + await callerInterface.write(parameters); + const output = await callerInterface.read(); + if (output.done) throw Error('TMP stream ended early'); + await callerInterface.end(); + await callerInterface.close(); + return output.value; } } + +export default RPCClient; diff --git a/src/RPC/RPCServer.ts b/src/RPC/RPCServer.ts index a520294f9..a55b19ab8 100644 --- a/src/RPC/RPCServer.ts +++ b/src/RPC/RPCServer.ts @@ -153,6 +153,7 @@ class RPCServer { ) { // This will take a buffer stream of json messages and set up service // handling for it. + // Constructing the PromiseCancellable for tracking the active stream let resolve: (value: void | PromiseLike) => void; const abortController = new AbortController(); const handlerProm2: PromiseCancellable = new PromiseCancellable( @@ -161,6 +162,7 @@ class RPCServer { }, abortController, ); + // Putting the PromiseCancellable into the active streams map this.activeStreams.add(handlerProm2); void handlerProm2.finally(() => this.activeStreams.delete(handlerProm2)); // While ReadableStream can be converted to AsyncIterable, we want it as @@ -170,6 +172,8 @@ class RPCServer { new rpcUtils.JsonToJsonMessageStream(), ); for await (const dataMessage of pojoStream) { + // FIXME: don't bother filtering, we should assume all input messages are request or notification. + // These should be checked by parsing, no need for a type field. // Filtering for request and notification messages if ( dataMessage.type === 'JsonRpcRequest' || diff --git a/src/RPC/types.ts b/src/RPC/types.ts index 9aebab06c..b877717fa 100644 --- a/src/RPC/types.ts +++ b/src/RPC/types.ts @@ -125,6 +125,40 @@ type UnaryHandler = Handler< Promise >; +/** + * @property read Read from the output generator + * @property write Write to the input generator + * @property inputGenerator Low level access to the input generator + * @property outputGenerator Low level access to the output generator + * @property end Signal end to the input generator + * @property close Signal early close to the output generator + * @property throw Throw to both generators + */ +type DuplexCallerInterface = { + read: () => Promise>; + write: (value: I) => Promise; + inputGenerator: AsyncGenerator; + outputGenerator: AsyncGenerator; + end: () => Promise; + close: () => Promise; + throw: (reason: any) => Promise; +}; + +type ServerCallerInterface = { + read: () => Promise>; + outputGenerator: AsyncGenerator; + close: () => Promise; + throw: (reason: any) => Promise; +}; + +type ClientCallerInterface = { + write: (value: I) => Promise; + result: Promise; + inputGenerator: AsyncGenerator; + end: () => Promise; + throw: (reason: any) => Promise; +}; + export type { JsonRpcRequest, JsonRpcNotification, @@ -137,4 +171,7 @@ export type { ServerStreamHandler, ClientStreamHandler, UnaryHandler, + DuplexCallerInterface, + ServerCallerInterface, + ClientCallerInterface, }; diff --git a/src/RPC/utils.ts b/src/RPC/utils.ts index fe4bd43dd..e64b09474 100644 --- a/src/RPC/utils.ts +++ b/src/RPC/utils.ts @@ -20,7 +20,10 @@ import * as validationErrors from '../validation/errors'; const jsonStreamParsers = require('@streamparser/json'); class JsonToJsonMessage implements Transformer { - protected parser = new jsonStreamParsers.JSONParser({ separator: '', paths: ['$'] }); + protected parser = new jsonStreamParsers.JSONParser({ + separator: '', + paths: ['$'], + }); start: TransformerStartCallback = async (controller) => { this.parser.onValue = (value) => { diff --git a/tests/RPC/RPCClient.test.ts b/tests/RPC/RPCClient.test.ts new file mode 100644 index 000000000..9026d1106 --- /dev/null +++ b/tests/RPC/RPCClient.test.ts @@ -0,0 +1,165 @@ +import type { ReadableWritablePair } from 'stream/web'; +import type { JSONValue } from '@/types'; +import type { JsonRpcRequest } from '@/RPC/types'; +import Logger, { LogLevel, StreamHandler } from '@matrixai/logger'; +import { testProp, fc } from '@fast-check/jest'; +import RPCClient from '@/RPC/RPCClient'; +import RPCServer from '@/RPC/RPCServer'; +import * as rpcTestUtils from './utils'; + +describe(`${RPCClient.name}`, () => { + const logger = new Logger(`${RPCServer.name} Test`, LogLevel.WARN, [ + new StreamHandler(), + ]); + + const methodName = 'testMethod'; + const specificMessageArb = fc + .array(rpcTestUtils.jsonRpcRequestArb(), { + minLength: 5, + }) + .noShrink(); + + testProp('generic duplex caller', [specificMessageArb], async (messages) => { + const inputStream = rpcTestUtils.jsonRpcStream(messages); + const [outputResult, outputStream] = rpcTestUtils.streamToArray(); + const streamPair: ReadableWritablePair = { + readable: inputStream, + writable: outputStream, + }; + const rpcClient = await RPCClient.createRPCClient({ + logger, + quicConnection: {}, + }); + const callerInterface = await rpcClient.duplexStreamCaller< + JSONValue, + JSONValue + >(methodName, { hello: 'world' }, streamPair); + while (true) { + const { value, done } = await callerInterface.read(); + if (done) { + // We have to end the writer otherwise the stream never closes + await callerInterface.end(); + break; + } + await callerInterface.write(value); + } + const expectedMessages: Array = messages.map((v) => { + const request: JsonRpcRequest = { + type: 'JsonRpcRequest', + jsonrpc: '2.0', + method: methodName, + id: null, + ...(v.params === undefined ? {} : { params: v.params }), + }; + return request; + }); + const outputMessages = (await outputResult).map((v) => + JSON.parse(v.toString()), + ); + expect(outputMessages).toStrictEqual(expectedMessages); + }); + testProp( + 'generic server stream caller', + [specificMessageArb, fc.jsonValue()], + async (messages, params) => { + const inputStream = rpcTestUtils.jsonRpcStream(messages); + const [outputResult, outputStream] = rpcTestUtils.streamToArray(); + const streamPair: ReadableWritablePair = { + readable: inputStream, + writable: outputStream, + }; + const rpcClient = await RPCClient.createRPCClient({ + logger, + quicConnection: {}, + }); + const callerInterface = await rpcClient.serverStreamCaller< + JSONValue, + JSONValue + >(methodName, params as JSONValue, {}, streamPair); + const values: Array = []; + for await (const value of callerInterface.outputGenerator) { + values.push(value); + } + const expectedValues = messages.map((v) => v.params); + expect(values).toStrictEqual(expectedValues); + expect((await outputResult)[0]?.toString()).toStrictEqual( + JSON.stringify({ + method: methodName, + type: 'JsonRpcRequest', + jsonrpc: '2.0', + id: null, + params, + }), + ); + }, + ); + testProp( + 'generic client stream caller', + [rpcTestUtils.jsonRpcRequestArb(), fc.array(fc.jsonValue())], + async (message, params) => { + const inputStream = rpcTestUtils.jsonRpcStream([message]); + const [outputResult, outputStream] = rpcTestUtils.streamToArray(); + const streamPair: ReadableWritablePair = { + readable: inputStream, + writable: outputStream, + }; + const rpcClient = await RPCClient.createRPCClient({ + logger, + quicConnection: {}, + }); + const callerInterface = await rpcClient.clientStreamCaller< + JSONValue, + JSONValue + >(methodName, {}, streamPair); + for (const param of params) { + await callerInterface.write(param as JSONValue); + } + await callerInterface.end(); + expect(await callerInterface.result).toStrictEqual(message.params); + const expectedOutput = params.map((v) => + JSON.stringify({ + method: methodName, + type: 'JsonRpcRequest', + jsonrpc: '2.0', + id: null, + params: v, + }), + ); + expect((await outputResult).map((v) => v.toString())).toStrictEqual( + expectedOutput, + ); + }, + ); + testProp( + 'generic unary caller', + [rpcTestUtils.jsonRpcRequestArb(), fc.jsonValue()], + async (message, params) => { + const inputStream = rpcTestUtils.jsonRpcStream([message]); + const [outputResult, outputStream] = rpcTestUtils.streamToArray(); + const streamPair: ReadableWritablePair = { + readable: inputStream, + writable: outputStream, + }; + const rpcClient = await RPCClient.createRPCClient({ + logger, + quicConnection: {}, + }); + const result = await rpcClient.unaryCaller( + methodName, + params as JSONValue, + {}, + streamPair, + ); + expect(result).toStrictEqual(message.params); + expect((await outputResult)[0]?.toString()).toStrictEqual( + JSON.stringify({ + method: methodName, + type: 'JsonRpcRequest', + jsonrpc: '2.0', + id: null, + params: params, + }), + ); + }, + ); +}); From ca524bc8feab2df540f5d4dc416585f90cce475e Mon Sep 17 00:00:00 2001 From: Brian Botha Date: Wed, 18 Jan 2023 16:47:24 +1100 Subject: [PATCH 11/44] feat: generic stream pair callback Related #501 [ci skip] --- src/RPC/RPCClient.ts | 37 +++++++++++++------------------------ src/RPC/types.ts | 6 ++++++ tests/RPC/RPCClient.test.ts | 15 +++++++-------- 3 files changed, 26 insertions(+), 32 deletions(-) diff --git a/src/RPC/RPCClient.ts b/src/RPC/RPCClient.ts index 6e293d3c8..4b8a0e319 100644 --- a/src/RPC/RPCClient.ts +++ b/src/RPC/RPCClient.ts @@ -1,34 +1,30 @@ import type { - DuplexCallerInterface, - ServerCallerInterface, ClientCallerInterface, + DuplexCallerInterface, JsonRpcRequest, + ServerCallerInterface, + StreamPairCreateCallback, } from './types'; import type { PromiseCancellable } from '@matrixai/async-cancellable'; import type { JSONValue, POJO } from 'types'; -import type { ReadableWritablePair } from 'stream/web'; import { StartStop } from '@matrixai/async-init/dist/StartStop'; import Logger from '@matrixai/logger'; import * as rpcErrors from './errors'; import * as rpcUtils from './utils'; -type QuicConnection = { - // EstablishStream: (stream: ReadableWritablePair) => Promise; -}; - interface RPCClient extends StartStop {} @StartStop() class RPCClient { static async createRPCClient({ - quicConnection, + streamPairCreateCallback, logger = new Logger(this.name), }: { - quicConnection: QuicConnection; + streamPairCreateCallback: StreamPairCreateCallback; logger: Logger; }) { logger.info(`Creating ${this.name}`); const rpcClient = new this({ - quicConnection, + streamPairCreateCallback, logger, }); await rpcClient.start(); @@ -38,17 +34,17 @@ class RPCClient { protected logger: Logger; protected activeStreams: Set> = new Set(); - protected quicConnection: QuicConnection; + protected streamPairCreateCallback: StreamPairCreateCallback; public constructor({ - quicConnection, + streamPairCreateCallback, logger, }: { - quicConnection: QuicConnection; + streamPairCreateCallback: StreamPairCreateCallback; logger: Logger; }) { this.logger = logger; - this.quicConnection = quicConnection; + this.streamPairCreateCallback = streamPairCreateCallback; } public async start(): Promise { @@ -69,9 +65,9 @@ class RPCClient { public async duplexStreamCaller( method: string, - metadata: POJO, - streamPair: ReadableWritablePair, + _metadata: POJO, ): Promise> { + const streamPair = await this.streamPairCreateCallback(); const inputStream = streamPair.readable.pipeThrough( new rpcUtils.JsonToJsonMessageStream(), ); @@ -118,7 +114,7 @@ class RPCClient { // Initiating the input generator await input.next(); - const inter: DuplexCallerInterface = { + return { read: () => output.next(), write: async (value: I) => { await input.next(value); @@ -136,19 +132,16 @@ class RPCClient { await output.throw(reason); }, }; - return inter; } public async serverStreamCaller( method: string, parameters: I, metadata: POJO, - streamPair: ReadableWritablePair, ): Promise> { const callerInterface = await this.duplexStreamCaller( method, metadata, - streamPair, ); await callerInterface.write(parameters); await callerInterface.end(); @@ -166,12 +159,10 @@ class RPCClient { public async clientStreamCaller( method: string, metadata: POJO, - streamPair: ReadableWritablePair, ): Promise> { const callerInterface = await this.duplexStreamCaller( method, metadata, - streamPair, ); const output = callerInterface .read() @@ -195,12 +186,10 @@ class RPCClient { method: string, parameters: I, metadata: POJO, - streamPair: ReadableWritablePair, ): Promise { const callerInterface = await this.duplexStreamCaller( method, metadata, - streamPair, ); await callerInterface.write(parameters); const output = await callerInterface.read(); diff --git a/src/RPC/types.ts b/src/RPC/types.ts index b877717fa..712b96d8a 100644 --- a/src/RPC/types.ts +++ b/src/RPC/types.ts @@ -1,6 +1,7 @@ import type { JSONValue, POJO } from '../types'; import type { ConnectionInfo } from '../network/types'; import type { ContextCancellable } from '../contexts/types'; +import type { ReadableWritablePair } from 'stream/web'; /** * This is the JSON RPC request object. this is the generic message type used for the RPC. @@ -159,6 +160,10 @@ type ClientCallerInterface = { throw: (reason: any) => Promise; }; +type StreamPairCreateCallback = () => Promise< + ReadableWritablePair +>; + export type { JsonRpcRequest, JsonRpcNotification, @@ -174,4 +179,5 @@ export type { DuplexCallerInterface, ServerCallerInterface, ClientCallerInterface, + StreamPairCreateCallback, }; diff --git a/tests/RPC/RPCClient.test.ts b/tests/RPC/RPCClient.test.ts index 9026d1106..dbb51cba5 100644 --- a/tests/RPC/RPCClient.test.ts +++ b/tests/RPC/RPCClient.test.ts @@ -27,13 +27,13 @@ describe(`${RPCClient.name}`, () => { writable: outputStream, }; const rpcClient = await RPCClient.createRPCClient({ + streamPairCreateCallback: async () => streamPair, logger, - quicConnection: {}, }); const callerInterface = await rpcClient.duplexStreamCaller< JSONValue, JSONValue - >(methodName, { hello: 'world' }, streamPair); + >(methodName, { hello: 'world' }); while (true) { const { value, done } = await callerInterface.read(); if (done) { @@ -69,13 +69,13 @@ describe(`${RPCClient.name}`, () => { writable: outputStream, }; const rpcClient = await RPCClient.createRPCClient({ + streamPairCreateCallback: async () => streamPair, logger, - quicConnection: {}, }); const callerInterface = await rpcClient.serverStreamCaller< JSONValue, JSONValue - >(methodName, params as JSONValue, {}, streamPair); + >(methodName, params as JSONValue, {}); const values: Array = []; for await (const value of callerInterface.outputGenerator) { values.push(value); @@ -104,13 +104,13 @@ describe(`${RPCClient.name}`, () => { writable: outputStream, }; const rpcClient = await RPCClient.createRPCClient({ + streamPairCreateCallback: async () => streamPair, logger, - quicConnection: {}, }); const callerInterface = await rpcClient.clientStreamCaller< JSONValue, JSONValue - >(methodName, {}, streamPair); + >(methodName, {}); for (const param of params) { await callerInterface.write(param as JSONValue); } @@ -141,14 +141,13 @@ describe(`${RPCClient.name}`, () => { writable: outputStream, }; const rpcClient = await RPCClient.createRPCClient({ + streamPairCreateCallback: async () => streamPair, logger, - quicConnection: {}, }); const result = await rpcClient.unaryCaller( methodName, params as JSONValue, {}, - streamPair, ); expect(result).toStrictEqual(message.params); expect((await outputResult)[0]?.toString()).toStrictEqual( From f1701629dc2a6016810190fea018653369f114f7 Mon Sep 17 00:00:00 2001 From: Brian Botha Date: Wed, 18 Jan 2023 17:44:26 +1100 Subject: [PATCH 12/44] fix: enforcing RPC message size limit Related #501 Related #502 [ci skip] --- src/RPC/errors.ts | 6 ++++++ src/RPC/utils.ts | 14 ++++++++++++-- tests/RPC/utils.test.ts | 23 +++++++++++++++++++++++ 3 files changed, 41 insertions(+), 2 deletions(-) diff --git a/src/RPC/errors.ts b/src/RPC/errors.ts index 13549aeb4..166706771 100644 --- a/src/RPC/errors.ts +++ b/src/RPC/errors.ts @@ -37,6 +37,11 @@ class ErrorRpcProtocal extends ErrorRpc { exitCode = sysexits.PROTOCOL; } +class ErrorRpcMessageLength extends ErrorRpc { + static description = 'RPC Message exceeds maximum size'; + exitCode = sysexits.DATAERR; +} + export { ErrorRpc, ErrorRpcRunning, @@ -46,4 +51,5 @@ export { ErrorRpcParse, ErrorRpcHandlerMissing, ErrorRpcProtocal, + ErrorRpcMessageLength, }; diff --git a/src/RPC/utils.ts b/src/RPC/utils.ts index e64b09474..f8a040b66 100644 --- a/src/RPC/utils.ts +++ b/src/RPC/utils.ts @@ -17,9 +17,14 @@ import * as rpcErrors from './errors'; import * as rpcUtils from './utils'; import * as utils from '../utils'; import * as validationErrors from '../validation/errors'; +import { ErrorRpcMessageLength } from "./errors"; const jsonStreamParsers = require('@streamparser/json'); class JsonToJsonMessage implements Transformer { + protected bytesWritten: number = 0; + + constructor(protected byteLimit: number) {} + protected parser = new jsonStreamParsers.JSONParser({ separator: '', paths: ['$'], @@ -29,6 +34,7 @@ class JsonToJsonMessage implements Transformer { this.parser.onValue = (value) => { const jsonMessage = rpcUtils.parseJsonRpcMessage(value.value); controller.enqueue(jsonMessage); + this.bytesWritten = 0; }; }; @@ -37,17 +43,21 @@ class JsonToJsonMessage implements Transformer { _controller, ) => { try { + this.bytesWritten += chunk.byteLength; this.parser.write(chunk); } catch (e) { throw new rpcErrors.ErrorRpcParse(undefined, { cause: e }); } + if (this.bytesWritten > this.byteLimit) { + throw new rpcErrors.ErrorRpcMessageLength(); + } }; } // TODO: rename to something more descriptive? class JsonToJsonMessageStream extends TransformStream { - constructor() { - super(new JsonToJsonMessage()); + constructor(byteLimit: number = 1024 * 1024) { + super(new JsonToJsonMessage(byteLimit)); } } diff --git a/tests/RPC/utils.test.ts b/tests/RPC/utils.test.ts index 169b82ab6..e028c6e97 100644 --- a/tests/RPC/utils.test.ts +++ b/tests/RPC/utils.test.ts @@ -68,6 +68,29 @@ describe('utils tests', () => { { numRuns: 1000 }, ); + testProp( + 'Message size limit is enforced', + [ + fc.array(rpcTestUtils.jsonRpcRequestArb(fc.string({ minLength: 100 })), { + minLength: 1, + }), + ], + async (messages) => { + const parsedStream = rpcTestUtils + .jsonRpcStream(messages) + .pipeThrough(new rpcTestUtils.BufferStreamToSnippedStream([10])) + .pipeThrough(new rpcUtils.JsonToJsonMessageStream(50)); + + const doThing = async () => { + for await (const _ of parsedStream) { + // No touch, only consume + } + }; + await expect(doThing()).rejects.toThrow(rpcErrors.ErrorRpcMessageLength); + }, + { numRuns: 1000 }, + ); + // TODO: // - Test for badly structured data }); From b26e86ec5f55678b85d2b79056f175f01fb47466 Mon Sep 17 00:00:00 2001 From: Brian Botha Date: Wed, 18 Jan 2023 17:55:24 +1100 Subject: [PATCH 13/44] fix: switched client and server to `CreateDestroy` Related #501 Related #502 [ci skip] --- src/RPC/RPCClient.ts | 22 ++++++++++------------ src/RPC/RPCServer.ts | 38 ++++++++++---------------------------- src/RPC/utils.ts | 1 - 3 files changed, 20 insertions(+), 41 deletions(-) diff --git a/src/RPC/RPCClient.ts b/src/RPC/RPCClient.ts index 4b8a0e319..5c1a64858 100644 --- a/src/RPC/RPCClient.ts +++ b/src/RPC/RPCClient.ts @@ -7,13 +7,13 @@ import type { } from './types'; import type { PromiseCancellable } from '@matrixai/async-cancellable'; import type { JSONValue, POJO } from 'types'; -import { StartStop } from '@matrixai/async-init/dist/StartStop'; +import { CreateDestroy, ready } from '@matrixai/async-init/dist/CreateDestroy'; import Logger from '@matrixai/logger'; import * as rpcErrors from './errors'; import * as rpcUtils from './utils'; -interface RPCClient extends StartStop {} -@StartStop() +interface RPCClient extends CreateDestroy {} +@CreateDestroy() class RPCClient { static async createRPCClient({ streamPairCreateCallback, @@ -27,7 +27,6 @@ class RPCClient { streamPairCreateCallback, logger, }); - await rpcClient.start(); logger.info(`Created ${this.name}`); return rpcClient; } @@ -47,22 +46,18 @@ class RPCClient { this.streamPairCreateCallback = streamPairCreateCallback; } - public async start(): Promise { - this.logger.info(`Starting ${this.constructor.name}`); - this.logger.info(`Started ${this.constructor.name}`); - } - - public async stop(): Promise { - this.logger.info(`Stopping ${this.constructor.name}`); + public async destroy(): Promise { + this.logger.info(`Destroying ${this.constructor.name}`); for await (const [stream] of this.activeStreams.entries()) { stream.cancel(new rpcErrors.ErrorRpcStopping()); } for await (const [stream] of this.activeStreams.entries()) { await stream; } - this.logger.info(`Stopped ${this.constructor.name}`); + this.logger.info(`Destroyed ${this.constructor.name}`); } + @ready(new rpcErrors.ErrorRpcDestroyed()) public async duplexStreamCaller( method: string, _metadata: POJO, @@ -134,6 +129,7 @@ class RPCClient { }; } + @ready(new rpcErrors.ErrorRpcDestroyed()) public async serverStreamCaller( method: string, parameters: I, @@ -156,6 +152,7 @@ class RPCClient { }; } + @ready(new rpcErrors.ErrorRpcDestroyed()) public async clientStreamCaller( method: string, metadata: POJO, @@ -182,6 +179,7 @@ class RPCClient { }; } + @ready(new rpcErrors.ErrorRpcDestroyed()) public async unaryCaller( method: string, parameters: I, diff --git a/src/RPC/RPCServer.ts b/src/RPC/RPCServer.ts index a55b19ab8..8335d623b 100644 --- a/src/RPC/RPCServer.ts +++ b/src/RPC/RPCServer.ts @@ -12,22 +12,15 @@ import type { JSONValue, POJO } from '../types'; import type { ConnectionInfo } from '../network/types'; import type { UnaryHandler } from './types'; import { ReadableStream } from 'stream/web'; -import { - CreateDestroyStartStop, - ready, -} from '@matrixai/async-init/dist/CreateDestroyStartStop'; +import { CreateDestroy, ready } from '@matrixai/async-init/dist/CreateDestroy'; import Logger from '@matrixai/logger'; import { PromiseCancellable } from '@matrixai/async-cancellable'; import * as rpcErrors from './errors'; import * as rpcUtils from './utils'; import * as grpcUtils from '../grpc/utils'; -// FIXME: Might need to be StartStop. Won't know for sure until it's used. -interface RPCServer extends CreateDestroyStartStop {} -@CreateDestroyStartStop( - new rpcErrors.ErrorRpcRunning(), - new rpcErrors.ErrorRpcDestroyed(), -) +interface RPCServer extends CreateDestroy {} +@CreateDestroy() class RPCServer { static async createRPCServer({ container, @@ -41,7 +34,6 @@ class RPCServer { container, logger, }); - await rpcServer.start(); logger.info(`Created ${this.name}`); return rpcServer; } @@ -64,27 +56,17 @@ class RPCServer { this.logger = logger; } - public async start(): Promise { - this.logger.info(`Starting ${this.constructor.name}`); - this.logger.info(`Started ${this.constructor.name}`); - } - - public async stop(): Promise { - this.logger.info(`Stopping ${this.constructor.name}`); + public async destroy(): Promise { + this.logger.info(`Destroying ${this.constructor.name}`); // Stopping any active steams const activeStreams = this.activeStreams; for await (const [activeStream] of activeStreams.entries()) { activeStream.cancel(new rpcErrors.ErrorRpcStopping()); } - this.logger.info(`Stopped ${this.constructor.name}`); - } - - public async destroy(): Promise { - this.logger.info(`Destroying ${this.constructor.name}`); this.logger.info(`Destroyed ${this.constructor.name}`); } - @ready(new rpcErrors.ErrorRpcNotRunning()) + @ready(new rpcErrors.ErrorRpcDestroyed()) public registerDuplexStreamHandler( method: string, handler: DuplexStreamHandler, @@ -92,7 +74,7 @@ class RPCServer { this.handlerMap.set(method, handler); } - @ready(new rpcErrors.ErrorRpcNotRunning()) + @ready(new rpcErrors.ErrorRpcDestroyed()) public registerUnaryHandler( method: string, handler: UnaryHandler, @@ -111,7 +93,7 @@ class RPCServer { this.handlerMap.set(method, wrapperDuplex); } - @ready(new rpcErrors.ErrorRpcNotRunning()) + @ready(new rpcErrors.ErrorRpcDestroyed()) public registerServerStreamHandler( method: string, handler: ServerStreamHandler, @@ -130,7 +112,7 @@ class RPCServer { this.handlerMap.set(method, wrapperDuplex); } - @ready(new rpcErrors.ErrorRpcNotRunning()) + @ready(new rpcErrors.ErrorRpcDestroyed()) public registerClientStreamHandler( method: string, handler: ClientStreamHandler, @@ -146,7 +128,7 @@ class RPCServer { this.handlerMap.set(method, wrapperDuplex); } - @ready(new rpcErrors.ErrorRpcNotRunning()) + @ready(new rpcErrors.ErrorRpcDestroyed()) public handleStream( streamPair: ReadableWritablePair, connectionInfo: ConnectionInfo, diff --git a/src/RPC/utils.ts b/src/RPC/utils.ts index f8a040b66..9edd3e30e 100644 --- a/src/RPC/utils.ts +++ b/src/RPC/utils.ts @@ -17,7 +17,6 @@ import * as rpcErrors from './errors'; import * as rpcUtils from './utils'; import * as utils from '../utils'; import * as validationErrors from '../validation/errors'; -import { ErrorRpcMessageLength } from "./errors"; const jsonStreamParsers = require('@streamparser/json'); class JsonToJsonMessage implements Transformer { From a4015942acb6763ad00152435d0d69a80dd6df9c Mon Sep 17 00:00:00 2001 From: Brian Botha Date: Thu, 19 Jan 2023 14:35:49 +1100 Subject: [PATCH 14/44] fix: fixing up message types and parsing There is now a reasonably enforced hierarchy of `message` => `request` | `response` => `requestMessage` | `requestNotification` | `responseResult` | `responseError`. Related #500 Related #501 [ci skip] --- src/RPC/RPCClient.ts | 17 ++- src/RPC/RPCServer.ts | 17 +-- src/RPC/types.ts | 36 +++--- src/RPC/utils.ts | 251 ++++++++++++++++++++---------------- tests/RPC/RPCClient.test.ts | 24 ++-- tests/RPC/RPCServer.test.ts | 31 +++-- tests/RPC/utils.test.ts | 30 +++-- tests/RPC/utils.ts | 135 +++++++++---------- 8 files changed, 288 insertions(+), 253 deletions(-) diff --git a/src/RPC/RPCClient.ts b/src/RPC/RPCClient.ts index 5c1a64858..3793c00c1 100644 --- a/src/RPC/RPCClient.ts +++ b/src/RPC/RPCClient.ts @@ -1,7 +1,7 @@ import type { ClientCallerInterface, DuplexCallerInterface, - JsonRpcRequest, + JsonRpcRequestMessage, ServerCallerInterface, StreamPairCreateCallback, } from './types'; @@ -64,7 +64,7 @@ class RPCClient { ): Promise> { const streamPair = await this.streamPairCreateCallback(); const inputStream = streamPair.readable.pipeThrough( - new rpcUtils.JsonToJsonMessageStream(), + new rpcUtils.JsonToJsonMessageStream(rpcUtils.parseJsonRpcResponse), ); const outputTransform = new rpcUtils.JsonMessageToJsonStream(); void outputTransform.readable.pipeTo(streamPair.writable); @@ -75,9 +75,8 @@ class RPCClient { try { while (true) { value = yield; - const message: JsonRpcRequest = { + const message: JsonRpcRequestMessage = { method, - type: 'JsonRpcRequest', jsonrpc: '2.0', id: null, params: value, @@ -96,12 +95,12 @@ class RPCClient { while (true) { const { value, done } = await reader.read(); if (done) break; - if ( - value?.type === 'JsonRpcRequest' || - value?.type === 'JsonRpcNotification' - ) { - yield value.params as O; + if ('error' in value) { + throw Error('TMP message was an error message', { + cause: value.error, + }); } + yield value.result as O; } }; const output = outputGen(); diff --git a/src/RPC/RPCServer.ts b/src/RPC/RPCServer.ts index 8335d623b..1e683dd61 100644 --- a/src/RPC/RPCServer.ts +++ b/src/RPC/RPCServer.ts @@ -17,7 +17,6 @@ import Logger from '@matrixai/logger'; import { PromiseCancellable } from '@matrixai/async-cancellable'; import * as rpcErrors from './errors'; import * as rpcUtils from './utils'; -import * as grpcUtils from '../grpc/utils'; interface RPCServer extends CreateDestroy {} @CreateDestroy() @@ -151,18 +150,10 @@ class RPCServer { // a generator. const inputGen = async function* () { const pojoStream = streamPair.readable.pipeThrough( - new rpcUtils.JsonToJsonMessageStream(), + new rpcUtils.JsonToJsonMessageStream(rpcUtils.parseJsonRpcRequest), ); for await (const dataMessage of pojoStream) { - // FIXME: don't bother filtering, we should assume all input messages are request or notification. - // These should be checked by parsing, no need for a type field. - // Filtering for request and notification messages - if ( - dataMessage.type === 'JsonRpcRequest' || - dataMessage.type === 'JsonRpcNotification' - ) { - yield dataMessage; - } + yield dataMessage; } }; const container = this.container; @@ -203,7 +194,6 @@ class RPCServer { ctx, )) { const responseMessage: JsonRpcResponseResult = { - type: 'JsonRpcResponseResult', jsonrpc: '2.0', result: response, id: null, @@ -216,10 +206,9 @@ class RPCServer { const rpcError: JsonRpcError = { code: e.exitCode, message: e.description, - data: grpcUtils.fromError(e), + data: rpcUtils.fromError(e), }; const rpcErrorMessage: JsonRpcResponseError = { - type: 'JsonRpcResponseError', jsonrpc: '2.0', error: rpcError, id: null, diff --git a/src/RPC/types.ts b/src/RPC/types.ts index 712b96d8a..3e1cfaa66 100644 --- a/src/RPC/types.ts +++ b/src/RPC/types.ts @@ -6,8 +6,7 @@ import type { ReadableWritablePair } from 'stream/web'; /** * This is the JSON RPC request object. this is the generic message type used for the RPC. */ -type JsonRpcRequest = { - type: 'JsonRpcRequest'; +type JsonRpcRequestMessage = { // A String specifying the version of the JSON-RPC protocol. MUST be exactly "2.0" jsonrpc: '2.0'; // A String containing the name of the method to be invoked. Method names that begin with the word rpc followed by a @@ -23,8 +22,7 @@ type JsonRpcRequest = { id: string | number | null; }; -type JsonRpcNotification = { - type: 'JsonRpcNotification'; +type JsonRpcRequestNotification = { // A String specifying the version of the JSON-RPC protocol. MUST be exactly "2.0" jsonrpc: '2.0'; // A String containing the name of the method to be invoked. Method names that begin with the word rpc followed by a @@ -37,7 +35,6 @@ type JsonRpcNotification = { }; type JsonRpcResponseResult = { - type: 'JsonRpcResponseResult'; // A String specifying the version of the JSON-RPC protocol. MUST be exactly "2.0". jsonrpc: '2.0'; // This member is REQUIRED on success. @@ -51,14 +48,13 @@ type JsonRpcResponseResult = { id: string | number | null; }; -type JsonRpcResponseError = { - type: 'JsonRpcResponseError'; +type JsonRpcResponseError = { // A String specifying the version of the JSON-RPC protocol. MUST be exactly "2.0". jsonrpc: '2.0'; // This member is REQUIRED on error. // This member MUST NOT exist if there was no error triggered during invocation. // The value for this member MUST be an Object as defined in section 5.1. - error: JsonRpcError; + error: JsonRpcError; // This member is REQUIRED. // It MUST be the same as the value of the id member in the Request Object. // If there was an error in detecting the id in the Request object (e.g. Parse error/Invalid Request), @@ -78,7 +74,7 @@ type JsonRpcResponseError = { // -32603 Internal error Internal JSON-RPC error. // -32000 to -32099 -type JsonRpcError = { +type JsonRpcError = { // A Number that indicates the error type that occurred. // This MUST be an integer. code: number; @@ -88,19 +84,20 @@ type JsonRpcError = { // A Primitive or Structured value that contains additional information about the error. // This may be omitted. // The value of this member is defined by the Server (e.g. detailed error information, nested errors etc.). - data?: T; + data?: JSONValue; }; -type JsonRpcResponse< - T extends JSONValue | unknown = unknown, - K extends JSONValue | unknown = unknown, -> = JsonRpcResponseResult | JsonRpcResponseError; +type JsonRpcRequest = + | JsonRpcRequestMessage + | JsonRpcRequestNotification; + +type JsonRpcResponse = + | JsonRpcResponseResult + | JsonRpcResponseError; type JsonRpcMessage = | JsonRpcRequest - | JsonRpcNotification - | JsonRpcResponseResult - | JsonRpcResponseError; + | JsonRpcResponse; // Handler types type Handler = ( @@ -165,11 +162,12 @@ type StreamPairCreateCallback = () => Promise< >; export type { - JsonRpcRequest, - JsonRpcNotification, + JsonRpcRequestMessage, + JsonRpcRequestNotification, JsonRpcResponseResult, JsonRpcResponseError, JsonRpcError, + JsonRpcRequest, JsonRpcResponse, JsonRpcMessage, DuplexStreamHandler, diff --git a/src/RPC/utils.ts b/src/RPC/utils.ts index 9edd3e30e..8881b40af 100644 --- a/src/RPC/utils.ts +++ b/src/RPC/utils.ts @@ -6,41 +6,44 @@ import type { import type { JsonRpcError, JsonRpcMessage, - JsonRpcNotification, - JsonRpcRequest, + JsonRpcRequestNotification, + JsonRpcRequestMessage, JsonRpcResponseError, JsonRpcResponseResult, + JsonRpcRequest, + JsonRpcResponse, } from 'RPC/types'; import type { JSONValue } from '../types'; import { TransformStream } from 'stream/web'; import * as rpcErrors from './errors'; -import * as rpcUtils from './utils'; import * as utils from '../utils'; import * as validationErrors from '../validation/errors'; const jsonStreamParsers = require('@streamparser/json'); -class JsonToJsonMessage implements Transformer { +class JsonToJsonMessage + implements Transformer +{ protected bytesWritten: number = 0; - constructor(protected byteLimit: number) {} + constructor( + protected messageParser: (message: unknown) => T, + protected byteLimit: number, + ) {} protected parser = new jsonStreamParsers.JSONParser({ separator: '', paths: ['$'], }); - start: TransformerStartCallback = async (controller) => { + start: TransformerStartCallback = async (controller) => { this.parser.onValue = (value) => { - const jsonMessage = rpcUtils.parseJsonRpcMessage(value.value); + const jsonMessage = this.messageParser(value.value); controller.enqueue(jsonMessage); this.bytesWritten = 0; }; }; - transform: TransformerTransformCallback = async ( - chunk, - _controller, - ) => { + transform: TransformerTransformCallback = async (chunk) => { try { this.bytesWritten += chunk.byteLength; this.parser.write(chunk); @@ -54,9 +57,15 @@ class JsonToJsonMessage implements Transformer { } // TODO: rename to something more descriptive? -class JsonToJsonMessageStream extends TransformStream { - constructor(byteLimit: number = 1024 * 1024) { - super(new JsonToJsonMessage(byteLimit)); +class JsonToJsonMessageStream extends TransformStream< + Buffer, + T +> { + constructor( + messageParser: (message: unknown) => T, + byteLimit: number = 1024 * 1024, + ) { + super(new JsonToJsonMessage(messageParser, byteLimit)); } } @@ -82,17 +91,6 @@ function parseJsonRpcRequest( if (!utils.isObject(message)) { throw new validationErrors.ErrorParse('must be a JSON POJO'); } - if (!('type' in message)) { - throw new validationErrors.ErrorParse('`type` property must be defined'); - } - if (typeof message.type !== 'string') { - throw new validationErrors.ErrorParse('`type` property must be a string'); - } - if (message.type !== 'JsonRpcRequest') { - throw new validationErrors.ErrorParse( - '`type` property must be "JsonRpcRequest"', - ); - } if (!('method' in message)) { throw new validationErrors.ErrorParse('`method` property must be defined'); } @@ -102,51 +100,36 @@ function parseJsonRpcRequest( // If ('params' in message && !utils.isObject(message.params)) { // throw new validationErrors.ErrorParse('`params` property must be a POJO'); // } - if (!('id' in message)) { + return message as JsonRpcRequest; +} + +function parseJsonRpcRequestMessage( + message: unknown, +): JsonRpcRequestMessage { + const jsonRequest = parseJsonRpcRequest(message); + if (!('id' in jsonRequest)) { throw new validationErrors.ErrorParse('`id` property must be defined'); } if ( - typeof message.id !== 'string' && - typeof message.id !== 'number' && - message.id !== null + typeof jsonRequest.id !== 'string' && + typeof jsonRequest.id !== 'number' && + jsonRequest.id !== null ) { throw new validationErrors.ErrorParse( '`id` property must be a string, number or null', ); } - return message as JsonRpcRequest; + return jsonRequest as JsonRpcRequestMessage; } -function parseJsonRpcNotification( +function parseJsonRpcRequestNotification( message: unknown, -): JsonRpcNotification { - if (!utils.isObject(message)) { - throw new validationErrors.ErrorParse('must be a JSON POJO'); - } - if (!('type' in message)) { - throw new validationErrors.ErrorParse('`type` property must be defined'); - } - if (typeof message.type !== 'string') { - throw new validationErrors.ErrorParse('`type` property must be a string'); - } - if (message.type !== 'JsonRpcNotification') { - throw new validationErrors.ErrorParse( - '`type` property must be "JsonRpcRequest"', - ); - } - if (!('method' in message)) { - throw new validationErrors.ErrorParse('`method` property must be defined'); - } - if (typeof message.method !== 'string') { - throw new validationErrors.ErrorParse('`method` property must be a string'); - } - // If ('params' in message && !utils.isObject(message.params)) { - // throw new validationErrors.ErrorParse('`params` property must be a POJO'); - // } - if ('id' in message) { +): JsonRpcRequestNotification { + const jsonRequest = parseJsonRpcRequest(message); + if ('id' in jsonRequest) { throw new validationErrors.ErrorParse('`id` property must not be defined'); } - return message as JsonRpcNotification; + return jsonRequest as JsonRpcRequestNotification; } function parseJsonRpcResponseResult( @@ -155,17 +138,6 @@ function parseJsonRpcResponseResult( if (!utils.isObject(message)) { throw new validationErrors.ErrorParse('must be a JSON POJO'); } - if (!('type' in message)) { - throw new validationErrors.ErrorParse('`type` property must be defined'); - } - if (typeof message.type !== 'string') { - throw new validationErrors.ErrorParse('`type` property must be a string'); - } - if (message.type !== 'JsonRpcResponseResult') { - throw new validationErrors.ErrorParse( - '`type` property must be "JsonRpcRequest"', - ); - } if (!('result' in message)) { throw new validationErrors.ErrorParse('`result` property must be defined'); } @@ -192,23 +164,10 @@ function parseJsonRpcResponseResult( return message as JsonRpcResponseResult; } -function parseJsonRpcResponseError( - message: unknown, -): JsonRpcResponseError { +function parseJsonRpcResponseError(message: unknown): JsonRpcResponseError { if (!utils.isObject(message)) { throw new validationErrors.ErrorParse('must be a JSON POJO'); } - if (!('type' in message)) { - throw new validationErrors.ErrorParse('`type` property must be defined'); - } - if (typeof message.type !== 'string') { - throw new validationErrors.ErrorParse('`type` property must be a string'); - } - if (message.type !== 'JsonRpcResponseError') { - throw new validationErrors.ErrorParse( - '`type` property must be "JsonRpcResponseError"', - ); - } if ('result' in message) { throw new validationErrors.ErrorParse( '`result` property must not be defined', @@ -217,7 +176,7 @@ function parseJsonRpcResponseError( if (!('error' in message)) { throw new validationErrors.ErrorParse('`error` property must be defined'); } - parseJsonRpcError(message.error); + parseJsonRpcError(message.error); if (!('id' in message)) { throw new validationErrors.ErrorParse('`id` property must be defined'); } @@ -230,12 +189,10 @@ function parseJsonRpcResponseError( '`id` property must be a string, number or null', ); } - return message as JsonRpcResponseError; + return message as JsonRpcResponseError; } -function parseJsonRpcError( - message: unknown, -): JsonRpcError { +function parseJsonRpcError(message: unknown): JsonRpcError { if (!utils.isObject(message)) { throw new validationErrors.ErrorParse('must be a JSON POJO'); } @@ -256,20 +213,35 @@ function parseJsonRpcError( // If ('data' in message && !utils.isObject(message.data)) { // throw new validationErrors.ErrorParse('`data` property must be a POJO'); // } - return message as JsonRpcError; + return message as JsonRpcError; } -function parseJsonRpcMessage( +function parseJsonRpcResponse( message: unknown, -): JsonRpcMessage { +): JsonRpcResponse { if (!utils.isObject(message)) { throw new validationErrors.ErrorParse('must be a JSON POJO'); } - if (!('type' in message)) { - throw new validationErrors.ErrorParse('`type` property must be defined'); + try { + return parseJsonRpcResponseResult(message); + } catch (e) { + // Do nothing } - if (typeof message.type !== 'string') { - throw new validationErrors.ErrorParse('`type` property must be a string'); + try { + return parseJsonRpcResponseError(message); + } catch (e) { + // Do nothing + } + throw new validationErrors.ErrorParse( + 'structure did not match a `JsonRpcResponse`', + ); +} + +function parseJsonRpcMessage( + message: unknown, +): JsonRpcMessage { + if (!utils.isObject(message)) { + throw new validationErrors.ErrorParse('must be a JSON POJO'); } if (!('jsonrpc' in message)) { throw new validationErrors.ErrorParse('`jsonrpc` property must be defined'); @@ -279,19 +251,79 @@ function parseJsonRpcMessage( '`jsonrpc` property must be a string of "2.0"', ); } - switch (message.type) { - case 'JsonRpcRequest': - return parseJsonRpcRequest(message); - case 'JsonRpcNotification': - return parseJsonRpcNotification(message); - case 'JsonRpcResponseResult': - return parseJsonRpcResponseResult(message); - case 'JsonRpcResponseError': - return parseJsonRpcResponseError(message); - default: - throw new validationErrors.ErrorParse( - '`type` property must be a valid type', - ); + try { + return parseJsonRpcRequest(message); + } catch { + // Do nothing + } + try { + return parseJsonRpcResponse(message); + } catch { + // Do nothing + } + throw new validationErrors.ErrorParse( + 'Message structure did not match a `JsonRpcMessage`', + ); +} + +/** + * Replacer function for serialising errors over GRPC (used by `JSON.stringify` + * in `fromError`) + * Polykey errors are handled by their inbuilt `toJSON` method , so this only + * serialises other errors + */ +function replacer(key: string, value: any): any { + if (value instanceof AggregateError) { + // AggregateError has an `errors` property + return { + type: value.constructor.name, + data: { + errors: value.errors, + message: value.message, + stack: value.stack, + }, + }; + } else if (value instanceof Error) { + // If it's some other type of error then only serialise the message and + // stack (and the type of the error) + return { + type: value.name, + data: { + message: value.message, + stack: value.stack, + }, + }; + } else { + // If it's not an error then just leave as is + return value; + } +} + +/** + * The same as `replacer`, however this will additionally filter out any + * sensitive data that should not be sent over the network when sending to an + * agent (as opposed to a client) + */ +function sensitiveReplacer(key: string, value: any) { + if (key === 'stack') { + return; + } else { + return replacer(key, value); + } +} + +/** + * Serializes Error instances into GRPC errors + * Use this on the sending side to send exceptions + * Do not send exceptions to clients you do not trust + * If sending to an agent (rather than a client), set sensitive to true to + * prevent sensitive information from being sent over the network + */ +function fromError(error: Error, sensitive: boolean = false) { + if (sensitive) { + return { error: JSON.stringify(error, sensitiveReplacer) }; + } else { + return { error: JSON.stringify(error, replacer) }; } } @@ -299,8 +331,11 @@ export { JsonToJsonMessageStream, JsonMessageToJsonStream, parseJsonRpcRequest, - parseJsonRpcNotification, + parseJsonRpcRequestMessage, + parseJsonRpcRequestNotification, parseJsonRpcResponseResult, parseJsonRpcResponseError, + parseJsonRpcResponse, parseJsonRpcMessage, + fromError, }; diff --git a/tests/RPC/RPCClient.test.ts b/tests/RPC/RPCClient.test.ts index dbb51cba5..8cfc69732 100644 --- a/tests/RPC/RPCClient.test.ts +++ b/tests/RPC/RPCClient.test.ts @@ -1,6 +1,6 @@ import type { ReadableWritablePair } from 'stream/web'; import type { JSONValue } from '@/types'; -import type { JsonRpcRequest } from '@/RPC/types'; +import type { JsonRpcRequestMessage } from '@/RPC/types'; import Logger, { LogLevel, StreamHandler } from '@matrixai/logger'; import { testProp, fc } from '@fast-check/jest'; import RPCClient from '@/RPC/RPCClient'; @@ -14,7 +14,7 @@ describe(`${RPCClient.name}`, () => { const methodName = 'testMethod'; const specificMessageArb = fc - .array(rpcTestUtils.jsonRpcRequestArb(), { + .array(rpcTestUtils.jsonRpcResponseResultArb(), { minLength: 5, }) .noShrink(); @@ -43,13 +43,12 @@ describe(`${RPCClient.name}`, () => { } await callerInterface.write(value); } - const expectedMessages: Array = messages.map((v) => { - const request: JsonRpcRequest = { - type: 'JsonRpcRequest', + const expectedMessages: Array = messages.map((v) => { + const request: JsonRpcRequestMessage = { jsonrpc: '2.0', method: methodName, id: null, - ...(v.params === undefined ? {} : { params: v.params }), + ...(v.result === undefined ? {} : { params: v.result }), }; return request; }); @@ -80,12 +79,11 @@ describe(`${RPCClient.name}`, () => { for await (const value of callerInterface.outputGenerator) { values.push(value); } - const expectedValues = messages.map((v) => v.params); + const expectedValues = messages.map((v) => v.result); expect(values).toStrictEqual(expectedValues); expect((await outputResult)[0]?.toString()).toStrictEqual( JSON.stringify({ method: methodName, - type: 'JsonRpcRequest', jsonrpc: '2.0', id: null, params, @@ -95,7 +93,7 @@ describe(`${RPCClient.name}`, () => { ); testProp( 'generic client stream caller', - [rpcTestUtils.jsonRpcRequestArb(), fc.array(fc.jsonValue())], + [rpcTestUtils.jsonRpcResponseResultArb(), fc.array(fc.jsonValue())], async (message, params) => { const inputStream = rpcTestUtils.jsonRpcStream([message]); const [outputResult, outputStream] = rpcTestUtils.streamToArray(); @@ -115,11 +113,10 @@ describe(`${RPCClient.name}`, () => { await callerInterface.write(param as JSONValue); } await callerInterface.end(); - expect(await callerInterface.result).toStrictEqual(message.params); + expect(await callerInterface.result).toStrictEqual(message.result); const expectedOutput = params.map((v) => JSON.stringify({ method: methodName, - type: 'JsonRpcRequest', jsonrpc: '2.0', id: null, params: v, @@ -132,7 +129,7 @@ describe(`${RPCClient.name}`, () => { ); testProp( 'generic unary caller', - [rpcTestUtils.jsonRpcRequestArb(), fc.jsonValue()], + [rpcTestUtils.jsonRpcResponseResultArb(), fc.jsonValue()], async (message, params) => { const inputStream = rpcTestUtils.jsonRpcStream([message]); const [outputResult, outputStream] = rpcTestUtils.streamToArray(); @@ -149,11 +146,10 @@ describe(`${RPCClient.name}`, () => { params as JSONValue, {}, ); - expect(result).toStrictEqual(message.params); + expect(result).toStrictEqual(message.result); expect((await outputResult)[0]?.toString()).toStrictEqual( JSON.stringify({ method: methodName, - type: 'JsonRpcRequest', jsonrpc: '2.0', id: null, params: params, diff --git a/tests/RPC/RPCServer.test.ts b/tests/RPC/RPCServer.test.ts index 528c0565b..0df1a8497 100644 --- a/tests/RPC/RPCServer.test.ts +++ b/tests/RPC/RPCServer.test.ts @@ -12,7 +12,6 @@ import type { ReadableWritablePair } from 'stream/web'; import { testProp, fc } from '@fast-check/jest'; import Logger, { LogLevel, StreamHandler } from '@matrixai/logger'; import RPCServer from '@/RPC/RPCServer'; -import * as rpcErrors from '@/RPC/errors'; import * as rpcTestUtils from './utils'; describe(`${RPCServer.name}`, () => { @@ -22,7 +21,7 @@ describe(`${RPCServer.name}`, () => { const methodName = 'testMethod'; const specificMessageArb = fc - .array(rpcTestUtils.jsonRpcRequestArb(fc.constant(methodName)), { + .array(rpcTestUtils.jsonRpcRequestMessageArb(fc.constant(methodName)), { minLength: 5, }) .noShrink(); @@ -83,7 +82,7 @@ describe(`${RPCServer.name}`, () => { ); const singleNumberMessageArb = fc.array( - rpcTestUtils.jsonRpcRequestArb( + rpcTestUtils.jsonRpcRequestMessageArb( fc.constant(methodName), fc.integer({ min: 1, max: 20 }), ), @@ -224,18 +223,19 @@ describe(`${RPCServer.name}`, () => { const [outputResult, outputStream] = rpcTestUtils.streamToArray(); let thing; let lastMessage: JsonRpcMessage | undefined; - const tapStream = new rpcTestUtils.TapStream( - async (_, iteration) => { - if (iteration === 2) { - // @ts-ignore: kidnap private property - const activeStreams = rpc.activeStreams.values(); - for (const activeStream of activeStreams) { - thing = activeStream; - activeStream.cancel(new rpcErrors.ErrorRpcStopping()); - } - } - }, - ); + const tapStream: any = {}; + // Const tapStream = new rpcTestUtils.TapStream( + // async (_, iteration) => { + // if (iteration === 2) { + // // @ts-ignore: kidnap private property + // const activeStreams = rpc.activeStreams.values(); + // for (const activeStream of activeStreams) { + // thing = activeStream; + // activeStream.cancel(new rpcErrors.ErrorRpcStopping()); + // } + // } + // }, + // ); await tapStream.readable.pipeTo(outputStream); const readWriteStream: ReadableWritablePair = { readable: stream, @@ -255,7 +255,6 @@ describe(`${RPCServer.name}`, () => { await expect(thing).toResolve(); // Last message should be an error message expect(lastMessage).toBeDefined(); - expect(lastMessage?.type).toBe('JsonRpcResponseError'); }, ); diff --git a/tests/RPC/utils.test.ts b/tests/RPC/utils.test.ts index e028c6e97..bd737b505 100644 --- a/tests/RPC/utils.test.ts +++ b/tests/RPC/utils.test.ts @@ -12,7 +12,9 @@ describe('utils tests', () => { async (messages) => { const parsedStream = rpcTestUtils .jsonRpcStream(messages) - .pipeThrough(new rpcUtils.JsonToJsonMessageStream()); // Converting back. + .pipeThrough( + new rpcUtils.JsonToJsonMessageStream(rpcUtils.parseJsonRpcMessage), + ); // Converting back. const asd = await AsyncIterable.as(parsedStream).toArray(); expect(asd).toEqual(messages); @@ -27,7 +29,9 @@ describe('utils tests', () => { const parsedStream = rpcTestUtils .jsonRpcStream(messages) .pipeThrough(new rpcTestUtils.BufferStreamToSnippedStream(snippattern)) // Imaginary internet here - .pipeThrough(new rpcUtils.JsonToJsonMessageStream()); // Converting back. + .pipeThrough( + new rpcUtils.JsonToJsonMessageStream(rpcUtils.parseJsonRpcMessage), + ); // Converting back. const asd = await AsyncIterable.as(parsedStream).toArray(); expect(asd).toStrictEqual(messages); @@ -50,7 +54,9 @@ describe('utils tests', () => { .jsonRpcStream(messages) .pipeThrough(new rpcTestUtils.BufferStreamToSnippedStream(snippattern)) // Imaginary internet here .pipeThrough(new rpcTestUtils.BufferStreamToNoisyStream(noise)) // Adding bad data to the stream - .pipeThrough(new rpcUtils.JsonToJsonMessageStream()); // Converting back. + .pipeThrough( + new rpcUtils.JsonToJsonMessageStream(rpcUtils.parseJsonRpcMessage), + ); // Converting back. await expect(AsyncIterable.as(parsedStream).toArray()).rejects.toThrow( rpcErrors.ErrorRpcParse, @@ -61,7 +67,7 @@ describe('utils tests', () => { testProp( 'can parse messages', - [rpcTestUtils.jsonRpcMessageArb], + [rpcTestUtils.jsonRpcMessageArb()], async (message) => { rpcUtils.parseJsonRpcMessage(message); }, @@ -71,15 +77,23 @@ describe('utils tests', () => { testProp( 'Message size limit is enforced', [ - fc.array(rpcTestUtils.jsonRpcRequestArb(fc.string({ minLength: 100 })), { - minLength: 1, - }), + fc.array( + rpcTestUtils.jsonRpcRequestMessageArb(fc.string({ minLength: 100 })), + { + minLength: 1, + }, + ), ], async (messages) => { const parsedStream = rpcTestUtils .jsonRpcStream(messages) .pipeThrough(new rpcTestUtils.BufferStreamToSnippedStream([10])) - .pipeThrough(new rpcUtils.JsonToJsonMessageStream(50)); + .pipeThrough( + new rpcUtils.JsonToJsonMessageStream( + rpcUtils.parseJsonRpcMessage, + 50, + ), + ); const doThing = async () => { for await (const _ of parsedStream) { diff --git a/tests/RPC/utils.ts b/tests/RPC/utils.ts index 81ce3913e..376f5705a 100644 --- a/tests/RPC/utils.ts +++ b/tests/RPC/utils.ts @@ -7,10 +7,12 @@ import type { POJO } from '@/types'; import type { JsonRpcError, JsonRpcMessage, - JsonRpcNotification, - JsonRpcRequest, + JsonRpcRequestNotification, + JsonRpcRequestMessage, JsonRpcResponseError, JsonRpcResponseResult, + JsonRpcResponse, + JsonRpcRequest, } from '@/RPC/types'; import type { JsonValue } from 'fast-check'; import { ReadableStream, WritableStream, TransformStream } from 'stream/web'; @@ -100,11 +102,13 @@ const jsonRpcStream = (messages: Array) => { }); }; -const jsonRpcRequestArb = ( +const safeJsonValueArb = fc + .jsonValue() + .map((value) => JSON.parse(JSON.stringify(value))); + +const jsonRpcRequestMessageArb = ( method: fc.Arbitrary = fc.string(), - params: fc.Arbitrary = fc - .jsonValue() - .map((value) => JSON.parse(JSON.stringify(value))), + params: fc.Arbitrary = safeJsonValueArb, ) => fc .record( @@ -119,37 +123,55 @@ const jsonRpcRequestArb = ( requiredKeys: ['type', 'jsonrpc', 'method', 'id'], }, ) + .noShrink() as fc.Arbitrary; + +const jsonRpcRequestNotificationArb = ( + method: fc.Arbitrary = fc.string(), + params: fc.Arbitrary = safeJsonValueArb, +) => + fc + .record( + { + type: fc.constant('JsonRpcNotification'), + jsonrpc: fc.constant('2.0'), + method: method, + params: params, + }, + { + requiredKeys: ['type', 'jsonrpc', 'method'], + }, + ) + .noShrink() as fc.Arbitrary; + +const jsonRpcRequestArb = ( + method: fc.Arbitrary = fc.string(), + params: fc.Arbitrary = safeJsonValueArb, +) => + fc + .oneof( + jsonRpcRequestMessageArb(method, params), + jsonRpcRequestNotificationArb(method, params), + ) .noShrink() as fc.Arbitrary; -const jsonRpcNotificationArb = fc - .record( - { - type: fc.constant('JsonRpcNotification'), +const jsonRpcResponseResultArb = ( + result: fc.Arbitrary = safeJsonValueArb, +) => + fc + .record({ + type: fc.constant('JsonRpcResponseResult'), jsonrpc: fc.constant('2.0'), - method: fc.string(), - params: fc.jsonValue(), - }, - { - requiredKeys: ['type', 'jsonrpc', 'method'], - }, - ) - .noShrink() as fc.Arbitrary; - -const jsonRpcResponseResultArb = fc - .record({ - type: fc.constant('JsonRpcResponseResult'), - jsonrpc: fc.constant('2.0'), - result: fc.jsonValue(), - id: fc.oneof(fc.string(), fc.integer(), fc.constant(null)), - }) - .noShrink() as fc.Arbitrary; + result: result, + id: fc.oneof(fc.string(), fc.integer(), fc.constant(null)), + }) + .noShrink() as fc.Arbitrary; const jsonRpcErrorArb = fc .record( { code: fc.integer(), message: fc.string(), - data: fc.jsonValue(), + data: safeJsonValueArb, }, { requiredKeys: ['code', 'message'], @@ -166,21 +188,28 @@ const jsonRpcResponseErrorArb = fc }) .noShrink() as fc.Arbitrary; -const jsonRpcMessageArb = fc - .oneof( - jsonRpcRequestArb(), - jsonRpcNotificationArb, - jsonRpcResponseResultArb, - jsonRpcResponseErrorArb, - ) - .noShrink() as fc.Arbitrary; +const jsonRpcResponseArb = ( + result: fc.Arbitrary = safeJsonValueArb, +) => + fc + .oneof(jsonRpcResponseResultArb(result), jsonRpcResponseErrorArb) + .noShrink() as fc.Arbitrary; + +const jsonRpcMessageArb = ( + method: fc.Arbitrary = fc.string(), + params: fc.Arbitrary = safeJsonValueArb, + result: fc.Arbitrary = safeJsonValueArb, +) => + fc + .oneof(jsonRpcRequestArb(method, params), jsonRpcResponseArb(result)) + .noShrink() as fc.Arbitrary; const snippingPatternArb = fc .array(fc.integer({ min: 1, max: 32 }), { minLength: 100, size: 'medium' }) .noShrink(); const jsonMessagesArb = fc - .array(jsonRpcRequestArb(), { minLength: 2 }) + .array(jsonRpcRequestMessageArb(), { minLength: 2 }) .noShrink(); function streamToArray(): [Promise>, WritableStream] { @@ -200,44 +229,20 @@ function streamToArray(): [Promise>, WritableStream] { return [result.p, outputStream]; } -class Tap implements Transformer { - protected iteration = 0; - protected tapIterator; - - constructor(tapIterator: (chunk: T, iteration: number) => Promise) { - this.tapIterator = tapIterator; - } - - transform: TransformerTransformCallback = async (chunk, controller) => { - await this.tapIterator(chunk, this.iteration); - controller.enqueue(chunk); - this.iteration += 1; - }; -} - -/** - * This is used to convert regular chunks into randomly sized chunks based on - * a provided pattern. This is to replicate randomness introduced by packets - * splitting up the data. - */ -class TapStream extends TransformStream { - constructor(tapIterator: (chunk: T, iteration: number) => Promise) { - super(new Tap(tapIterator)); - } -} - export { BufferStreamToSnippedStream, BufferStreamToNoisyStream, jsonRpcStream, + safeJsonValueArb, + jsonRpcRequestMessageArb, + jsonRpcRequestNotificationArb, jsonRpcRequestArb, - jsonRpcNotificationArb, jsonRpcResponseResultArb, jsonRpcErrorArb, jsonRpcResponseErrorArb, + jsonRpcResponseArb, jsonRpcMessageArb, snippingPatternArb, jsonMessagesArb, streamToArray, - TapStream, }; From cf13564e1e38f145d155a840c74914d7a8704763 Mon Sep 17 00:00:00 2001 From: Brian Botha Date: Thu, 19 Jan 2023 17:38:04 +1100 Subject: [PATCH 15/44] tests: client and server integration tests Related #500 Related #501 [ci skip] --- src/RPC/RPCServer.ts | 3 +- tests/RPC/RPC.test.ts | 175 ++++++++++++++++++++++++++++++++++++++++++ tests/RPC/utils.ts | 50 ++++++++++++ 3 files changed, 227 insertions(+), 1 deletion(-) create mode 100644 tests/RPC/RPC.test.ts diff --git a/src/RPC/RPCServer.ts b/src/RPC/RPCServer.ts index 1e683dd61..1a23569a4 100644 --- a/src/RPC/RPCServer.ts +++ b/src/RPC/RPCServer.ts @@ -170,8 +170,9 @@ class RPCServer { throw Error('TMP Stream closed early'); } const method = leadingMetadataMessage.value.method; - const _metadata = leadingMetadataMessage.value.params; + const initialParams = leadingMetadataMessage.value.params; const dataGen = async function* () { + yield initialParams as JSONValue; for await (const data of input) { yield data.params as JSONValue; } diff --git a/tests/RPC/RPC.test.ts b/tests/RPC/RPC.test.ts new file mode 100644 index 000000000..c7516001d --- /dev/null +++ b/tests/RPC/RPC.test.ts @@ -0,0 +1,175 @@ +import type { + ClientStreamHandler, + DuplexStreamHandler, + ServerStreamHandler, + UnaryHandler, +} from '@/RPC/types'; +import type { ConnectionInfo } from '@/network/types'; +import type { JSONValue } from '@/types'; +import { fc, testProp } from '@fast-check/jest'; +import Logger, { LogLevel, StreamHandler } from '@matrixai/logger'; +import RPCServer from '@/RPC/RPCServer'; +import RPCClient from '@/RPC/RPCClient'; +import * as rpcTestUtils from './utils'; + +describe('RPC', () => { + const logger = new Logger(`RPC Test`, LogLevel.WARN, [new StreamHandler()]); + + const methodName = 'testMethod'; + + testProp( + 'RPC communication with duplex stream', + [fc.array(rpcTestUtils.safeJsonValueArb, { minLength: 1 })], + async (values) => { + const { clientPair, serverPair } = rpcTestUtils.createTapPairs< + Buffer, + Buffer + >(); + + const container = {}; + const rpc = await RPCServer.createRPCServer({ container, logger }); + + const duplexHandler: DuplexStreamHandler = + async function* (input, _container, _connectionInfo, _ctx) { + for await (const val of input) { + yield val; + } + }; + + rpc.registerDuplexStreamHandler(methodName, duplexHandler); + rpc.handleStream(serverPair, {} as ConnectionInfo); + + const rpcClient = await RPCClient.createRPCClient({ + streamPairCreateCallback: async () => clientPair, + logger, + }); + + const callerInterface = await rpcClient.duplexStreamCaller( + methodName, + {}, + ); + for (const value of values) { + await callerInterface.write(value); + expect((await callerInterface.read()).value).toStrictEqual(value); + } + await callerInterface.end(); + expect((await callerInterface.read()).value).toBeUndefined(); + expect((await callerInterface.read()).done).toBeTrue(); + }, + ); + + testProp( + 'RPC communication with client stream', + [fc.integer({ min: 1, max: 100 })], + async (value) => { + const { clientPair, serverPair } = rpcTestUtils.createTapPairs< + Buffer, + Buffer + >(); + + const container = {}; + const rpc = await RPCServer.createRPCServer({ container, logger }); + + const serverStreamHandler: ServerStreamHandler = + async function* (input, _container, _connectionInfo, _ctx) { + for (let i = 0; i < input; i++) { + yield i; + } + }; + + rpc.registerServerStreamHandler(methodName, serverStreamHandler); + rpc.handleStream(serverPair, {} as ConnectionInfo); + + const rpcClient = await RPCClient.createRPCClient({ + streamPairCreateCallback: async () => clientPair, + logger, + }); + + const callerInterface = await rpcClient.serverStreamCaller< + number, + number + >(methodName, value, {}); + + const outputs: Array = []; + for await (const num of callerInterface.outputGenerator) { + outputs.push(num); + } + expect(outputs.length).toEqual(value); + }, + { numRuns: 1 }, + ); + + testProp( + 'RPC communication with server stream', + [fc.array(fc.integer(), { minLength: 1 })], + async (values) => { + const { clientPair, serverPair } = rpcTestUtils.createTapPairs< + Buffer, + Buffer + >(); + + const container = {}; + const rpc = await RPCServer.createRPCServer({ container, logger }); + + const clientStreamhandler: ClientStreamHandler = async ( + input, + ) => { + let acc = 0; + for await (const number of input) { + acc += number; + } + return acc; + }; + rpc.registerClientStreamHandler(methodName, clientStreamhandler); + rpc.handleStream(serverPair, {} as ConnectionInfo); + + const rpcClient = await RPCClient.createRPCClient({ + streamPairCreateCallback: async () => clientPair, + logger, + }); + + const callerInterface = await rpcClient.clientStreamCaller< + number, + number + >(methodName, {}); + for (const value of values) { + await callerInterface.write(value); + } + await callerInterface.end(); + + const expectedResult = values.reduce((p, c) => p + c); + await expect(callerInterface.result).resolves.toEqual(expectedResult); + }, + ); + + testProp( + 'RPC communication with unary call', + [rpcTestUtils.safeJsonValueArb], + async (value) => { + const { clientPair, serverPair } = rpcTestUtils.createTapPairs< + Buffer, + Buffer + >(); + + const container = {}; + const rpc = await RPCServer.createRPCServer({ container, logger }); + + const unaryCaller: UnaryHandler = async (input) => + input; + rpc.registerUnaryHandler(methodName, unaryCaller); + rpc.handleStream(serverPair, {} as ConnectionInfo); + + const rpcClient = await RPCClient.createRPCClient({ + streamPairCreateCallback: async () => clientPair, + logger, + }); + + const result = await rpcClient.unaryCaller( + methodName, + value, + {}, + ); + expect(result).toStrictEqual(value); + }, + ); +}); diff --git a/tests/RPC/utils.ts b/tests/RPC/utils.ts index 376f5705a..762084d5f 100644 --- a/tests/RPC/utils.ts +++ b/tests/RPC/utils.ts @@ -2,6 +2,7 @@ import type { Transformer, TransformerFlushCallback, TransformerTransformCallback, + ReadableWritablePair, } from 'stream/web'; import type { POJO } from '@/types'; import type { @@ -229,6 +230,53 @@ function streamToArray(): [Promise>, WritableStream] { return [result.p, outputStream]; } +class TapTransformer implements Transformer { + protected iteration = 0; + + constructor( + protected tapCallback: (chunk: I, iteration: number) => Promise, + ) {} + + transform: TransformerTransformCallback = async (chunk, controller) => { + await this.tapCallback(chunk, this.iteration); + controller.enqueue(chunk); + this.iteration += 1; + }; +} + +type TapCallback = (chunk: T, iteration: number) => Promise; + +/** + * This is used to convert regular chunks into randomly sized chunks based on + * a provided pattern. This is to replicate randomness introduced by packets + * splitting up the data. + */ +class TapTransformerStream extends TransformStream { + constructor(tapCallback: TapCallback = async () => {}) { + super(new TapTransformer(tapCallback)); + } +} + +function createTapPairs( + forwardTapCallback: TapCallback = async () => {}, + reverseTapCallback: TapCallback = async () => {}, +) { + const forwardTap = new TapTransformerStream(forwardTapCallback); + const reverseTap = new TapTransformerStream(reverseTapCallback); + const clientPair: ReadableWritablePair = { + readable: reverseTap.readable, + writable: forwardTap.writable, + }; + const serverPair: ReadableWritablePair = { + readable: forwardTap.readable, + writable: reverseTap.writable, + }; + return { + clientPair, + serverPair, + }; +} + export { BufferStreamToSnippedStream, BufferStreamToNoisyStream, @@ -245,4 +293,6 @@ export { snippingPatternArb, jsonMessagesArb, streamToArray, + TapTransformerStream, + createTapPairs, }; From f20b3429fe027c2834b639fda4ca3b13a5d6c376 Mon Sep 17 00:00:00 2001 From: Brian Botha Date: Thu, 19 Jan 2023 19:54:28 +1100 Subject: [PATCH 16/44] feat: client error handling Related #501 [ci skip] --- src/RPC/RPCClient.ts | 54 ++++++++++++++----- src/RPC/RPCServer.ts | 9 ++-- src/RPC/errors.ts | 6 +++ src/RPC/utils.ts | 105 ++++++++++++++++++++++++++++++++++-- tests/RPC/RPCClient.test.ts | 38 +++++++++++++ tests/RPC/utils.ts | 57 ++++++++++---------- 6 files changed, 221 insertions(+), 48 deletions(-) diff --git a/src/RPC/RPCClient.ts b/src/RPC/RPCClient.ts index 3793c00c1..01a12631b 100644 --- a/src/RPC/RPCClient.ts +++ b/src/RPC/RPCClient.ts @@ -5,12 +5,13 @@ import type { ServerCallerInterface, StreamPairCreateCallback, } from './types'; -import type { PromiseCancellable } from '@matrixai/async-cancellable'; import type { JSONValue, POJO } from 'types'; +import { PromiseCancellable } from '@matrixai/async-cancellable'; import { CreateDestroy, ready } from '@matrixai/async-init/dist/CreateDestroy'; import Logger from '@matrixai/logger'; import * as rpcErrors from './errors'; import * as rpcUtils from './utils'; +import { promise } from '../utils/index'; interface RPCClient extends CreateDestroy {} @CreateDestroy() @@ -62,6 +63,22 @@ class RPCClient { method: string, _metadata: POJO, ): Promise> { + // Constructing the PromiseCancellable for tracking the active stream + const inputFinishedProm = promise(); + const outputFinishedProm = promise(); + const abortController = new AbortController(); + const handlerProm: PromiseCancellable = new PromiseCancellable( + (resolve) => { + Promise.all([inputFinishedProm.p, outputFinishedProm.p]).finally(() => + resolve(), + ); + }, + abortController, + ); + // Putting the PromiseCancellable into the active streams map + this.activeStreams.add(handlerProm); + void handlerProm.finally(() => this.activeStreams.delete(handlerProm)); + const streamPair = await this.streamPairCreateCallback(); const inputStream = streamPair.readable.pipeThrough( new rpcUtils.JsonToJsonMessageStream(rpcUtils.parseJsonRpcResponse), @@ -87,27 +104,32 @@ class RPCClient { await writer.abort(e); } finally { await writer.close(); + inputFinishedProm.resolveP(); } }; const outputGen = async function* (): AsyncGenerator { - const reader = inputStream.getReader(); - while (true) { - const { value, done } = await reader.read(); - if (done) break; - if ('error' in value) { - throw Error('TMP message was an error message', { - cause: value.error, - }); + try { + for await (const result of inputStream) { + if ('error' in result) { + throw rpcUtils.toError(result.error.data); + } + yield result.result as O; } - yield value.result as O; + } finally { + outputFinishedProm.resolveP(); } }; const output = outputGen(); const input = inputGen(); // Initiating the input generator await input.next(); - + // Hooking up abort signals + abortController.signal.addEventListener('abort', async () => { + await output.throw(abortController.signal.reason); + await input.throw(abortController.signal.reason); + }); + // Returning interface return { read: () => output.next(), write: async (value: I) => { @@ -163,7 +185,11 @@ class RPCClient { const output = callerInterface .read() .then(({ value, done }) => { - if (done) throw Error('TMP Stream closed early'); + if (done) { + throw new rpcErrors.ErrorRpcRemoteError( + 'Stream ended before response', + ); + } return value; }) .finally(async () => { @@ -190,7 +216,9 @@ class RPCClient { ); await callerInterface.write(parameters); const output = await callerInterface.read(); - if (output.done) throw Error('TMP stream ended early'); + if (output.done) { + throw new rpcErrors.ErrorRpcRemoteError('Stream ended before response'); + } await callerInterface.end(); await callerInterface.close(); return output.value; diff --git a/src/RPC/RPCServer.ts b/src/RPC/RPCServer.ts index 1a23569a4..088eb04f6 100644 --- a/src/RPC/RPCServer.ts +++ b/src/RPC/RPCServer.ts @@ -43,6 +43,7 @@ class RPCServer { protected handlerMap: Map> = new Map(); private activeStreams: Set> = new Set(); + private events: EventTarget = new EventTarget(); public constructor({ container, @@ -137,15 +138,15 @@ class RPCServer { // Constructing the PromiseCancellable for tracking the active stream let resolve: (value: void | PromiseLike) => void; const abortController = new AbortController(); - const handlerProm2: PromiseCancellable = new PromiseCancellable( + const handlerProm: PromiseCancellable = new PromiseCancellable( (resolve_) => { resolve = resolve_; }, abortController, ); // Putting the PromiseCancellable into the active streams map - this.activeStreams.add(handlerProm2); - void handlerProm2.finally(() => this.activeStreams.delete(handlerProm2)); + this.activeStreams.add(handlerProm); + void handlerProm.finally(() => this.activeStreams.delete(handlerProm)); // While ReadableStream can be converted to AsyncIterable, we want it as // a generator. const inputGen = async function* () { @@ -167,7 +168,7 @@ class RPCServer { if (ctx.signal.aborted) throw ctx.signal.reason; const leadingMetadataMessage = await input.next(); if (leadingMetadataMessage.done === true) { - throw Error('TMP Stream closed early'); + throw new rpcErrors.ErrorRpcProtocal('Stream ended before response'); } const method = leadingMetadataMessage.value.method; const initialParams = leadingMetadataMessage.value.params; diff --git a/src/RPC/errors.ts b/src/RPC/errors.ts index 166706771..e47722205 100644 --- a/src/RPC/errors.ts +++ b/src/RPC/errors.ts @@ -42,6 +42,11 @@ class ErrorRpcMessageLength extends ErrorRpc { exitCode = sysexits.DATAERR; } +class ErrorRpcRemoteError extends ErrorRpc { + static description = 'RPC Message exceeds maximum size'; + exitCode = sysexits.UNAVAILABLE; +} + export { ErrorRpc, ErrorRpcRunning, @@ -52,4 +57,5 @@ export { ErrorRpcHandlerMissing, ErrorRpcProtocal, ErrorRpcMessageLength, + ErrorRpcRemoteError, }; diff --git a/src/RPC/utils.ts b/src/RPC/utils.ts index 8881b40af..d8e949fcd 100644 --- a/src/RPC/utils.ts +++ b/src/RPC/utils.ts @@ -15,9 +15,11 @@ import type { } from 'RPC/types'; import type { JSONValue } from '../types'; import { TransformStream } from 'stream/web'; +import { AbstractError } from '@matrixai/errors'; import * as rpcErrors from './errors'; import * as utils from '../utils'; import * as validationErrors from '../validation/errors'; +import * as errors from '../errors'; const jsonStreamParsers = require('@streamparser/json'); class JsonToJsonMessage @@ -313,7 +315,7 @@ function sensitiveReplacer(key: string, value: any) { } /** - * Serializes Error instances into GRPC errors + * Serializes Error instances into RPC errors * Use this on the sending side to send exceptions * Do not send exceptions to clients you do not trust * If sending to an agent (rather than a client), set sensitive to true to @@ -321,12 +323,108 @@ function sensitiveReplacer(key: string, value: any) { */ function fromError(error: Error, sensitive: boolean = false) { if (sensitive) { - return { error: JSON.stringify(error, sensitiveReplacer) }; + return JSON.stringify(error, sensitiveReplacer); } else { - return { error: JSON.stringify(error, replacer) }; + return JSON.stringify(error, replacer); } } +/** + * Error constructors for non-Polykey errors + * Allows these errors to be reconstructed from GRPC metadata + */ +const standardErrors = { + Error, + TypeError, + SyntaxError, + ReferenceError, + EvalError, + RangeError, + URIError, + AggregateError, + AbstractError, +}; + +/** + * Reviver function for deserialising errors sent over GRPC (used by + * `JSON.parse` in `toError`) + * The final result returned will always be an error - if the deserialised + * data is of an unknown type then this will be wrapped as an + * `ErrorPolykeyUnknown` + */ +function reviver(key: string, value: any): any { + // If the value is an error then reconstruct it + if ( + typeof value === 'object' && + typeof value.type === 'string' && + typeof value.data === 'object' + ) { + try { + let eClass = errors[value.type]; + if (eClass != null) return eClass.fromJSON(value); + eClass = standardErrors[value.type]; + if (eClass != null) { + let e; + switch (eClass) { + case AbstractError: + return eClass.fromJSON(); + case AggregateError: + if ( + !Array.isArray(value.data.errors) || + typeof value.data.message !== 'string' || + ('stack' in value.data && typeof value.data.stack !== 'string') + ) { + throw new TypeError(`cannot decode JSON to ${value.type}`); + } + e = new eClass(value.data.errors, value.data.message); + e.stack = value.data.stack; + break; + default: + if ( + typeof value.data.message !== 'string' || + ('stack' in value.data && typeof value.data.stack !== 'string') + ) { + throw new TypeError(`Cannot decode JSON to ${value.type}`); + } + e = new eClass(value.data.message); + e.stack = value.data.stack; + break; + } + return e; + } + } catch (e) { + // If `TypeError` which represents decoding failure + // then return value as-is + // Any other exception is a bug + if (!(e instanceof TypeError)) { + throw e; + } + } + // Other values are returned as-is + return value; + } else if (key === '') { + // Root key will be '' + // Reaching here means the root JSON value is not a valid exception + // Therefore ErrorPolykeyUnknown is only ever returned at the top-level + const error = new errors.ErrorPolykeyUnknown('Unknown error JSON', { + data: { + json: value, + }, + }); + return error; + } else { + return value; + } +} + +function toError(errorData) { + if (errorData == null) return new rpcErrors.ErrorRpcRemoteError(); + const error = JSON.parse(errorData, reviver); + return new rpcErrors.ErrorRpcRemoteError(error.message, { + cause: error, + }); +} + export { JsonToJsonMessageStream, JsonMessageToJsonStream, @@ -338,4 +436,5 @@ export { parseJsonRpcResponse, parseJsonRpcMessage, fromError, + toError, }; diff --git a/tests/RPC/RPCClient.test.ts b/tests/RPC/RPCClient.test.ts index 8cfc69732..98f05d955 100644 --- a/tests/RPC/RPCClient.test.ts +++ b/tests/RPC/RPCClient.test.ts @@ -5,6 +5,7 @@ import Logger, { LogLevel, StreamHandler } from '@matrixai/logger'; import { testProp, fc } from '@fast-check/jest'; import RPCClient from '@/RPC/RPCClient'; import RPCServer from '@/RPC/RPCServer'; +import * as rpcErrors from '@/RPC/errors'; import * as rpcTestUtils from './utils'; describe(`${RPCClient.name}`, () => { @@ -157,4 +158,41 @@ describe(`${RPCClient.name}`, () => { ); }, ); + + testProp.only( + 'generic duplex caller can throw received error message', + [ + fc.array(rpcTestUtils.jsonRpcResponseResultArb()), + rpcTestUtils.jsonRpcResponseErrorArb(), + ], + async (messages, errorMessage) => { + const inputStream = rpcTestUtils.jsonRpcStream([ + ...messages, + errorMessage, + ]); + const [outputResult, outputStream] = rpcTestUtils.streamToArray(); + const streamPair: ReadableWritablePair = { + readable: inputStream, + writable: outputStream, + }; + const rpcClient = await RPCClient.createRPCClient({ + streamPairCreateCallback: async () => streamPair, + logger, + }); + const callerInterface = await rpcClient.duplexStreamCaller< + JSONValue, + JSONValue + >(methodName, { hello: 'world' }); + const consumeToError = async () => { + for await (const _ of callerInterface.outputGenerator) { + // No touch, just consume + } + }; + await expect(consumeToError()).rejects.toThrow( + rpcErrors.ErrorRpcRemoteError, + ); + await callerInterface.end(); + await outputResult; + }, + ); }); diff --git a/tests/RPC/utils.ts b/tests/RPC/utils.ts index 762084d5f..b11ae020e 100644 --- a/tests/RPC/utils.ts +++ b/tests/RPC/utils.ts @@ -19,6 +19,7 @@ import type { JsonValue } from 'fast-check'; import { ReadableStream, WritableStream, TransformStream } from 'stream/web'; import { fc } from '@fast-check/jest'; import * as utils from '@/utils'; +import { fromError } from '@/RPC/utils'; class BufferStreamToSnipped implements Transformer { protected buffer = Buffer.alloc(0); @@ -107,6 +108,8 @@ const safeJsonValueArb = fc .jsonValue() .map((value) => JSON.parse(JSON.stringify(value))); +const idArb = fc.oneof(fc.string(), fc.integer(), fc.constant(null)); + const jsonRpcRequestMessageArb = ( method: fc.Arbitrary = fc.string(), params: fc.Arbitrary = safeJsonValueArb, @@ -114,14 +117,13 @@ const jsonRpcRequestMessageArb = ( fc .record( { - type: fc.constant('JsonRpcRequest'), jsonrpc: fc.constant('2.0'), method: method, params: params, - id: fc.oneof(fc.string(), fc.integer(), fc.constant(null)), + id: idArb, }, { - requiredKeys: ['type', 'jsonrpc', 'method', 'id'], + requiredKeys: ['jsonrpc', 'method', 'id'], }, ) .noShrink() as fc.Arbitrary; @@ -133,13 +135,12 @@ const jsonRpcRequestNotificationArb = ( fc .record( { - type: fc.constant('JsonRpcNotification'), jsonrpc: fc.constant('2.0'), method: method, params: params, }, { - requiredKeys: ['type', 'jsonrpc', 'method'], + requiredKeys: ['jsonrpc', 'method'], }, ) .noShrink() as fc.Arbitrary; @@ -160,40 +161,40 @@ const jsonRpcResponseResultArb = ( ) => fc .record({ - type: fc.constant('JsonRpcResponseResult'), jsonrpc: fc.constant('2.0'), result: result, - id: fc.oneof(fc.string(), fc.integer(), fc.constant(null)), + id: idArb, }) .noShrink() as fc.Arbitrary; -const jsonRpcErrorArb = fc - .record( - { - code: fc.integer(), - message: fc.string(), - data: safeJsonValueArb, - }, - { - requiredKeys: ['code', 'message'], - }, - ) - .noShrink() as fc.Arbitrary; +const jsonRpcErrorArb = (error: Error = new Error('test error')) => + fc + .record( + { + code: fc.integer(), + message: fc.string(), + data: fc.constant(fromError(error)), + }, + { + requiredKeys: ['code', 'message'], + }, + ) + .noShrink() as fc.Arbitrary; -const jsonRpcResponseErrorArb = fc - .record({ - type: fc.constant('JsonRpcResponseError'), - jsonrpc: fc.constant('2.0'), - error: jsonRpcErrorArb, - id: fc.oneof(fc.string(), fc.integer(), fc.constant(null)), - }) - .noShrink() as fc.Arbitrary; +const jsonRpcResponseErrorArb = (error?: Error) => + fc + .record({ + jsonrpc: fc.constant('2.0'), + error: jsonRpcErrorArb(error), + id: idArb, + }) + .noShrink() as fc.Arbitrary; const jsonRpcResponseArb = ( result: fc.Arbitrary = safeJsonValueArb, ) => fc - .oneof(jsonRpcResponseResultArb(result), jsonRpcResponseErrorArb) + .oneof(jsonRpcResponseResultArb(result), jsonRpcResponseErrorArb()) .noShrink() as fc.Arbitrary; const jsonRpcMessageArb = ( From 7cb904008d322644fe08dd17878cef5288a1cade Mon Sep 17 00:00:00 2001 From: Brian Botha Date: Fri, 20 Jan 2023 15:16:32 +1100 Subject: [PATCH 17/44] fix: small bug with uncaught promise Related #501 Related #502 [ci skip] --- src/RPC/RPCClient.ts | 12 ++++---- src/RPC/RPCServer.ts | 13 ++++++--- tests/RPC/RPC.test.ts | 39 +++++++++++++++---------- tests/RPC/RPCClient.test.ts | 7 ++++- tests/RPC/RPCServer.test.ts | 58 +++++++++++++++++++++---------------- 5 files changed, 77 insertions(+), 52 deletions(-) diff --git a/src/RPC/RPCClient.ts b/src/RPC/RPCClient.ts index 01a12631b..bf916a21b 100644 --- a/src/RPC/RPCClient.ts +++ b/src/RPC/RPCClient.ts @@ -67,7 +67,7 @@ class RPCClient { const inputFinishedProm = promise(); const outputFinishedProm = promise(); const abortController = new AbortController(); - const handlerProm: PromiseCancellable = new PromiseCancellable( + const handlerProm: PromiseCancellable = new PromiseCancellable( (resolve) => { Promise.all([inputFinishedProm.p, outputFinishedProm.p]).finally(() => resolve(), @@ -77,14 +77,16 @@ class RPCClient { ); // Putting the PromiseCancellable into the active streams map this.activeStreams.add(handlerProm); - void handlerProm.finally(() => this.activeStreams.delete(handlerProm)); + void handlerProm + .finally(() => this.activeStreams.delete(handlerProm)) + .catch(() => {}); const streamPair = await this.streamPairCreateCallback(); const inputStream = streamPair.readable.pipeThrough( new rpcUtils.JsonToJsonMessageStream(rpcUtils.parseJsonRpcResponse), ); const outputTransform = new rpcUtils.JsonMessageToJsonStream(); - void outputTransform.readable.pipeTo(streamPair.writable); + void outputTransform.readable.pipeTo(streamPair.writable).catch(() => {}); const inputGen = async function* (): AsyncGenerator { const writer = outputTransform.writable.getWriter(); @@ -100,8 +102,6 @@ class RPCClient { }; await writer.write(message); } - } catch (e) { - await writer.abort(e); } finally { await writer.close(); inputFinishedProm.resolveP(); @@ -144,7 +144,7 @@ class RPCClient { await output.return(); }, throw: async (reason: any) => { - await input.throw(reason); + await input.return(); await output.throw(reason); }, }; diff --git a/src/RPC/RPCServer.ts b/src/RPC/RPCServer.ts index 088eb04f6..5d83d3134 100644 --- a/src/RPC/RPCServer.ts +++ b/src/RPC/RPCServer.ts @@ -59,10 +59,12 @@ class RPCServer { public async destroy(): Promise { this.logger.info(`Destroying ${this.constructor.name}`); // Stopping any active steams - const activeStreams = this.activeStreams; - for await (const [activeStream] of activeStreams.entries()) { + for await (const [activeStream] of this.activeStreams.entries()) { activeStream.cancel(new rpcErrors.ErrorRpcStopping()); } + for await (const [activeStream] of this.activeStreams.entries()) { + await activeStream; + } this.logger.info(`Destroyed ${this.constructor.name}`); } @@ -146,7 +148,9 @@ class RPCServer { ); // Putting the PromiseCancellable into the active streams map this.activeStreams.add(handlerProm); - void handlerProm.finally(() => this.activeStreams.delete(handlerProm)); + void handlerProm + .finally(() => this.activeStreams.delete(handlerProm)) + .catch(() => {}); // While ReadableStream can be converted to AsyncIterable, we want it as // a generator. const inputGen = async function* () { @@ -238,7 +242,8 @@ class RPCServer { }); void outputStream .pipeThrough(new rpcUtils.JsonMessageToJsonStream()) - .pipeTo(streamPair.writable); + .pipeTo(streamPair.writable) + .catch(() => {}); } } diff --git a/tests/RPC/RPC.test.ts b/tests/RPC/RPC.test.ts index c7516001d..2a977ee67 100644 --- a/tests/RPC/RPC.test.ts +++ b/tests/RPC/RPC.test.ts @@ -27,7 +27,7 @@ describe('RPC', () => { >(); const container = {}; - const rpc = await RPCServer.createRPCServer({ container, logger }); + const rpcServer = await RPCServer.createRPCServer({ container, logger }); const duplexHandler: DuplexStreamHandler = async function* (input, _container, _connectionInfo, _ctx) { @@ -36,8 +36,8 @@ describe('RPC', () => { } }; - rpc.registerDuplexStreamHandler(methodName, duplexHandler); - rpc.handleStream(serverPair, {} as ConnectionInfo); + rpcServer.registerDuplexStreamHandler(methodName, duplexHandler); + rpcServer.handleStream(serverPair, {} as ConnectionInfo); const rpcClient = await RPCClient.createRPCClient({ streamPairCreateCallback: async () => clientPair, @@ -55,11 +55,13 @@ describe('RPC', () => { await callerInterface.end(); expect((await callerInterface.read()).value).toBeUndefined(); expect((await callerInterface.read()).done).toBeTrue(); + await rpcServer.destroy(); + await rpcClient.destroy(); }, ); testProp( - 'RPC communication with client stream', + 'RPC communication with server stream', [fc.integer({ min: 1, max: 100 })], async (value) => { const { clientPair, serverPair } = rpcTestUtils.createTapPairs< @@ -68,7 +70,7 @@ describe('RPC', () => { >(); const container = {}; - const rpc = await RPCServer.createRPCServer({ container, logger }); + const rpcServer = await RPCServer.createRPCServer({ container, logger }); const serverStreamHandler: ServerStreamHandler = async function* (input, _container, _connectionInfo, _ctx) { @@ -77,8 +79,8 @@ describe('RPC', () => { } }; - rpc.registerServerStreamHandler(methodName, serverStreamHandler); - rpc.handleStream(serverPair, {} as ConnectionInfo); + rpcServer.registerServerStreamHandler(methodName, serverStreamHandler); + rpcServer.handleStream(serverPair, {} as ConnectionInfo); const rpcClient = await RPCClient.createRPCClient({ streamPairCreateCallback: async () => clientPair, @@ -95,13 +97,14 @@ describe('RPC', () => { outputs.push(num); } expect(outputs.length).toEqual(value); + await rpcServer.destroy(); + await rpcClient.destroy(); }, - { numRuns: 1 }, ); testProp( - 'RPC communication with server stream', - [fc.array(fc.integer(), { minLength: 1 })], + 'RPC communication with client stream', + [fc.array(fc.integer(), { minLength: 1 }).noShrink()], async (values) => { const { clientPair, serverPair } = rpcTestUtils.createTapPairs< Buffer, @@ -109,7 +112,7 @@ describe('RPC', () => { >(); const container = {}; - const rpc = await RPCServer.createRPCServer({ container, logger }); + const rpcServer = await RPCServer.createRPCServer({ container, logger }); const clientStreamhandler: ClientStreamHandler = async ( input, @@ -120,8 +123,8 @@ describe('RPC', () => { } return acc; }; - rpc.registerClientStreamHandler(methodName, clientStreamhandler); - rpc.handleStream(serverPair, {} as ConnectionInfo); + rpcServer.registerClientStreamHandler(methodName, clientStreamhandler); + rpcServer.handleStream(serverPair, {} as ConnectionInfo); const rpcClient = await RPCClient.createRPCClient({ streamPairCreateCallback: async () => clientPair, @@ -139,6 +142,8 @@ describe('RPC', () => { const expectedResult = values.reduce((p, c) => p + c); await expect(callerInterface.result).resolves.toEqual(expectedResult); + await rpcServer.destroy(); + await rpcClient.destroy(); }, ); @@ -152,12 +157,12 @@ describe('RPC', () => { >(); const container = {}; - const rpc = await RPCServer.createRPCServer({ container, logger }); + const rpcServer = await RPCServer.createRPCServer({ container, logger }); const unaryCaller: UnaryHandler = async (input) => input; - rpc.registerUnaryHandler(methodName, unaryCaller); - rpc.handleStream(serverPair, {} as ConnectionInfo); + rpcServer.registerUnaryHandler(methodName, unaryCaller); + rpcServer.handleStream(serverPair, {} as ConnectionInfo); const rpcClient = await RPCClient.createRPCClient({ streamPairCreateCallback: async () => clientPair, @@ -170,6 +175,8 @@ describe('RPC', () => { {}, ); expect(result).toStrictEqual(value); + await rpcServer.destroy(); + await rpcClient.destroy(); }, ); }); diff --git a/tests/RPC/RPCClient.test.ts b/tests/RPC/RPCClient.test.ts index 98f05d955..25d2215d1 100644 --- a/tests/RPC/RPCClient.test.ts +++ b/tests/RPC/RPCClient.test.ts @@ -57,6 +57,7 @@ describe(`${RPCClient.name}`, () => { JSON.parse(v.toString()), ); expect(outputMessages).toStrictEqual(expectedMessages); + await rpcClient.destroy(); }); testProp( 'generic server stream caller', @@ -90,6 +91,7 @@ describe(`${RPCClient.name}`, () => { params, }), ); + await rpcClient.destroy(); }, ); testProp( @@ -126,6 +128,7 @@ describe(`${RPCClient.name}`, () => { expect((await outputResult).map((v) => v.toString())).toStrictEqual( expectedOutput, ); + await rpcClient.destroy(); }, ); testProp( @@ -156,10 +159,11 @@ describe(`${RPCClient.name}`, () => { params: params, }), ); + await rpcClient.destroy(); }, ); - testProp.only( + testProp( 'generic duplex caller can throw received error message', [ fc.array(rpcTestUtils.jsonRpcResponseResultArb()), @@ -193,6 +197,7 @@ describe(`${RPCClient.name}`, () => { ); await callerInterface.end(); await outputResult; + await rpcClient.destroy(); }, ); }); diff --git a/tests/RPC/RPCServer.test.ts b/tests/RPC/RPCServer.test.ts index 0df1a8497..86e224afa 100644 --- a/tests/RPC/RPCServer.test.ts +++ b/tests/RPC/RPCServer.test.ts @@ -32,7 +32,7 @@ describe(`${RPCServer.name}`, () => { async (messages) => { const stream = rpcTestUtils.jsonRpcStream(messages); const container = {}; - const rpc = await RPCServer.createRPCServer({ container, logger }); + const rpcServer = await RPCServer.createRPCServer({ container, logger }); const [outputResult, outputStream] = rpcTestUtils.streamToArray(); const readWriteStream: ReadableWritablePair = { readable: stream, @@ -47,9 +47,10 @@ describe(`${RPCServer.name}`, () => { } }; - rpc.registerDuplexStreamHandler(methodName, duplexHandler); - rpc.handleStream(readWriteStream, {} as ConnectionInfo); + rpcServer.registerDuplexStreamHandler(methodName, duplexHandler); + rpcServer.handleStream(readWriteStream, {} as ConnectionInfo); await outputResult; + await rpcServer.destroy(); }, ); @@ -59,7 +60,7 @@ describe(`${RPCServer.name}`, () => { async (messages) => { const stream = rpcTestUtils.jsonRpcStream(messages); const container = {}; - const rpc = await RPCServer.createRPCServer({ container, logger }); + const rpcServer = await RPCServer.createRPCServer({ container, logger }); const [outputResult, outputStream] = rpcTestUtils.streamToArray(); const readWriteStream: ReadableWritablePair = { readable: stream, @@ -75,9 +76,10 @@ describe(`${RPCServer.name}`, () => { return count; }; - rpc.registerClientStreamHandler(methodName, clientHandler); - rpc.handleStream(readWriteStream, {} as ConnectionInfo); + rpcServer.registerClientStreamHandler(methodName, clientHandler); + rpcServer.handleStream(readWriteStream, {} as ConnectionInfo); await outputResult; + await rpcServer.destroy(); }, ); @@ -98,7 +100,7 @@ describe(`${RPCServer.name}`, () => { async (messages) => { const stream = rpcTestUtils.jsonRpcStream(messages); const container = {}; - const rpc = await RPCServer.createRPCServer({ container, logger }); + const rpcServer = await RPCServer.createRPCServer({ container, logger }); const [outputResult, outputStream] = rpcTestUtils.streamToArray(); const readWriteStream: ReadableWritablePair = { readable: stream, @@ -112,9 +114,10 @@ describe(`${RPCServer.name}`, () => { } }; - rpc.registerServerStreamHandler(methodName, serverHandler); - rpc.handleStream(readWriteStream, {} as ConnectionInfo); + rpcServer.registerServerStreamHandler(methodName, serverHandler); + rpcServer.handleStream(readWriteStream, {} as ConnectionInfo); await outputResult; + await rpcServer.destroy(); }, ); @@ -124,7 +127,7 @@ describe(`${RPCServer.name}`, () => { async (messages) => { const stream = rpcTestUtils.jsonRpcStream(messages); const container = {}; - const rpc = await RPCServer.createRPCServer({ container, logger }); + const rpcServer = await RPCServer.createRPCServer({ container, logger }); const [outputResult, outputStream] = rpcTestUtils.streamToArray(); const readWriteStream: ReadableWritablePair = { readable: stream, @@ -140,9 +143,10 @@ describe(`${RPCServer.name}`, () => { return input; }; - rpc.registerUnaryHandler(methodName, unaryHandler); - rpc.handleStream(readWriteStream, {} as ConnectionInfo); + rpcServer.registerUnaryHandler(methodName, unaryHandler); + rpcServer.handleStream(readWriteStream, {} as ConnectionInfo); await outputResult; + await rpcServer.destroy(); }, ); @@ -156,7 +160,7 @@ describe(`${RPCServer.name}`, () => { B: Symbol('b'), C: Symbol('c'), }; - const rpc = await RPCServer.createRPCServer({ container, logger }); + const rpcServer = await RPCServer.createRPCServer({ container, logger }); const [outputResult, outputStream] = rpcTestUtils.streamToArray(); const readWriteStream: ReadableWritablePair = { readable: stream, @@ -171,9 +175,10 @@ describe(`${RPCServer.name}`, () => { } }; - rpc.registerDuplexStreamHandler(methodName, duplexHandler); - rpc.handleStream(readWriteStream, {} as ConnectionInfo); + rpcServer.registerDuplexStreamHandler(methodName, duplexHandler); + rpcServer.handleStream(readWriteStream, {} as ConnectionInfo); await outputResult; + await rpcServer.destroy(); }, ); @@ -191,7 +196,7 @@ describe(`${RPCServer.name}`, () => { remotePort: 12341 as Port, }; const container = {}; - const rpc = await RPCServer.createRPCServer({ container, logger }); + const rpcServer = await RPCServer.createRPCServer({ container, logger }); const [outputResult, outputStream] = rpcTestUtils.streamToArray(); const readWriteStream: ReadableWritablePair = { readable: stream, @@ -205,9 +210,10 @@ describe(`${RPCServer.name}`, () => { yield val; } }; - rpc.registerDuplexStreamHandler(methodName, duplexHandler); - rpc.handleStream(readWriteStream, {} as ConnectionInfo); + rpcServer.registerDuplexStreamHandler(methodName, duplexHandler); + rpcServer.handleStream(readWriteStream, {} as ConnectionInfo); await outputResult; + await rpcServer.destroy(); }, ); @@ -219,7 +225,7 @@ describe(`${RPCServer.name}`, () => { async (messages) => { const stream = rpcTestUtils.jsonRpcStream(messages); const container = {}; - const rpc = await RPCServer.createRPCServer({ container, logger }); + const rpcServer = await RPCServer.createRPCServer({ container, logger }); const [outputResult, outputStream] = rpcTestUtils.streamToArray(); let thing; let lastMessage: JsonRpcMessage | undefined; @@ -228,7 +234,7 @@ describe(`${RPCServer.name}`, () => { // async (_, iteration) => { // if (iteration === 2) { // // @ts-ignore: kidnap private property - // const activeStreams = rpc.activeStreams.values(); + // const activeStreams = rpcServer.activeStreams.values(); // for (const activeStream of activeStreams) { // thing = activeStream; // activeStream.cancel(new rpcErrors.ErrorRpcStopping()); @@ -249,19 +255,20 @@ describe(`${RPCServer.name}`, () => { yield val; } }; - rpc.registerDuplexStreamHandler(methodName, duplexHandler); - rpc.handleStream(readWriteStream, {} as ConnectionInfo); + rpcServer.registerDuplexStreamHandler(methodName, duplexHandler); + rpcServer.handleStream(readWriteStream, {} as ConnectionInfo); await outputResult; await expect(thing).toResolve(); // Last message should be an error message expect(lastMessage).toBeDefined(); + await rpcServer.destroy(); }, ); testProp('Handler yields nothing', [specificMessageArb], async (messages) => { const stream = rpcTestUtils.jsonRpcStream(messages); const container = {}; - const rpc = await RPCServer.createRPCServer({ container, logger }); + const rpcServer = await RPCServer.createRPCServer({ container, logger }); const [outputResult, outputStream] = rpcTestUtils.streamToArray(); const readWriteStream: ReadableWritablePair = { readable: stream, @@ -275,10 +282,11 @@ describe(`${RPCServer.name}`, () => { } }; - rpc.registerDuplexStreamHandler(methodName, duplexHandler); - rpc.handleStream(readWriteStream, {} as ConnectionInfo); + rpcServer.registerDuplexStreamHandler(methodName, duplexHandler); + rpcServer.handleStream(readWriteStream, {} as ConnectionInfo); await outputResult; // We're just expecting no errors + await rpcServer.destroy(); }); // TODO: From c0c99200772c1dd7d783c637acbd1667a085ec2e Mon Sep 17 00:00:00 2001 From: Brian Botha Date: Fri, 20 Jan 2023 18:13:48 +1100 Subject: [PATCH 18/44] fix: changed generic callers to use a ReadableWritablePair as the interface Related #501 [ci skip] --- src/RPC/RPCClient.ts | 169 ++++++++---------------------------- src/RPC/types.ts | 37 -------- src/RPC/utils.ts | 52 +++++++++++ tests/RPC/RPC.test.ts | 22 +++-- tests/RPC/RPCClient.test.ts | 22 ++--- 5 files changed, 115 insertions(+), 187 deletions(-) diff --git a/src/RPC/RPCClient.ts b/src/RPC/RPCClient.ts index bf916a21b..bc7ea8bfd 100644 --- a/src/RPC/RPCClient.ts +++ b/src/RPC/RPCClient.ts @@ -1,17 +1,10 @@ -import type { - ClientCallerInterface, - DuplexCallerInterface, - JsonRpcRequestMessage, - ServerCallerInterface, - StreamPairCreateCallback, -} from './types'; +import type { StreamPairCreateCallback } from './types'; import type { JSONValue, POJO } from 'types'; -import { PromiseCancellable } from '@matrixai/async-cancellable'; +import type { ReadableWritablePair } from 'stream/web'; import { CreateDestroy, ready } from '@matrixai/async-init/dist/CreateDestroy'; import Logger from '@matrixai/logger'; import * as rpcErrors from './errors'; import * as rpcUtils from './utils'; -import { promise } from '../utils/index'; interface RPCClient extends CreateDestroy {} @CreateDestroy() @@ -33,7 +26,6 @@ class RPCClient { } protected logger: Logger; - protected activeStreams: Set> = new Set(); protected streamPairCreateCallback: StreamPairCreateCallback; public constructor({ @@ -49,12 +41,6 @@ class RPCClient { public async destroy(): Promise { this.logger.info(`Destroying ${this.constructor.name}`); - for await (const [stream] of this.activeStreams.entries()) { - stream.cancel(new rpcErrors.ErrorRpcStopping()); - } - for await (const [stream] of this.activeStreams.entries()) { - await stream; - } this.logger.info(`Destroyed ${this.constructor.name}`); } @@ -62,91 +48,25 @@ class RPCClient { public async duplexStreamCaller( method: string, _metadata: POJO, - ): Promise> { - // Constructing the PromiseCancellable for tracking the active stream - const inputFinishedProm = promise(); - const outputFinishedProm = promise(); - const abortController = new AbortController(); - const handlerProm: PromiseCancellable = new PromiseCancellable( - (resolve) => { - Promise.all([inputFinishedProm.p, outputFinishedProm.p]).finally(() => - resolve(), - ); - }, - abortController, - ); - // Putting the PromiseCancellable into the active streams map - this.activeStreams.add(handlerProm); - void handlerProm - .finally(() => this.activeStreams.delete(handlerProm)) - .catch(() => {}); - + ): Promise> { const streamPair = await this.streamPairCreateCallback(); - const inputStream = streamPair.readable.pipeThrough( - new rpcUtils.JsonToJsonMessageStream(rpcUtils.parseJsonRpcResponse), - ); - const outputTransform = new rpcUtils.JsonMessageToJsonStream(); - void outputTransform.readable.pipeTo(streamPair.writable).catch(() => {}); - - const inputGen = async function* (): AsyncGenerator { - const writer = outputTransform.writable.getWriter(); - let value: I; - try { - while (true) { - value = yield; - const message: JsonRpcRequestMessage = { - method, - jsonrpc: '2.0', - id: null, - params: value, - }; - await writer.write(message); - } - } finally { - await writer.close(); - inputFinishedProm.resolveP(); - } - }; + const outputStream = streamPair.readable + .pipeThrough( + new rpcUtils.JsonToJsonMessageStream(rpcUtils.parseJsonRpcResponse), + ) + .pipeThrough(new rpcUtils.ClientOutputTransformerStream()); + const inputMessageTransformer = + new rpcUtils.ClientInputTransformerStream(method); + void inputMessageTransformer.readable + .pipeThrough(new rpcUtils.JsonMessageToJsonStream()) + .pipeTo(streamPair.writable) + .catch(() => {}); + const inputStream = inputMessageTransformer.writable; - const outputGen = async function* (): AsyncGenerator { - try { - for await (const result of inputStream) { - if ('error' in result) { - throw rpcUtils.toError(result.error.data); - } - yield result.result as O; - } - } finally { - outputFinishedProm.resolveP(); - } - }; - const output = outputGen(); - const input = inputGen(); - // Initiating the input generator - await input.next(); - // Hooking up abort signals - abortController.signal.addEventListener('abort', async () => { - await output.throw(abortController.signal.reason); - await input.throw(abortController.signal.reason); - }); // Returning interface return { - read: () => output.next(), - write: async (value: I) => { - await input.next(value); - }, - inputGenerator: input, - outputGenerator: output, - end: async () => { - await input.return(); - }, - close: async () => { - await output.return(); - }, - throw: async (reason: any) => { - await input.return(); - await output.throw(reason); - }, + readable: outputStream, + writable: inputStream, }; } @@ -155,52 +75,37 @@ class RPCClient { method: string, parameters: I, metadata: POJO, - ): Promise> { + ) { const callerInterface = await this.duplexStreamCaller( method, metadata, ); - await callerInterface.write(parameters); - await callerInterface.end(); + const writer = callerInterface.writable.getWriter(); + await writer.write(parameters); + await writer.close(); - return { - read: () => callerInterface.read(), - outputGenerator: callerInterface.outputGenerator, - close: () => callerInterface.close(), - throw: async (reason: any) => { - await callerInterface.outputGenerator.throw(reason); - }, - }; + return callerInterface.readable; } @ready(new rpcErrors.ErrorRpcDestroyed()) public async clientStreamCaller( method: string, metadata: POJO, - ): Promise> { + ) { const callerInterface = await this.duplexStreamCaller( method, metadata, ); - const output = callerInterface - .read() - .then(({ value, done }) => { - if (done) { - throw new rpcErrors.ErrorRpcRemoteError( - 'Stream ended before response', - ); - } - return value; - }) - .finally(async () => { - await callerInterface.close(); - }); + const reader = callerInterface.readable.getReader(); + const output = reader.read().then(({ value, done }) => { + if (done) { + throw new rpcErrors.ErrorRpcRemoteError('Stream ended before response'); + } + return value; + }); return { - write: (value: I) => callerInterface.write(value), - result: output, - inputGenerator: callerInterface.inputGenerator, - end: () => callerInterface.end(), - throw: (reason: any) => callerInterface.throw(reason), + output, + writable: callerInterface.writable, }; } @@ -214,13 +119,15 @@ class RPCClient { method, metadata, ); - await callerInterface.write(parameters); - const output = await callerInterface.read(); + const reader = callerInterface.readable.getReader(); + const writer = callerInterface.writable.getWriter(); + await writer.write(parameters); + const output = await reader.read(); if (output.done) { throw new rpcErrors.ErrorRpcRemoteError('Stream ended before response'); } - await callerInterface.end(); - await callerInterface.close(); + await reader.cancel(); + await writer.close(); return output.value; } } diff --git a/src/RPC/types.ts b/src/RPC/types.ts index 3e1cfaa66..884709b68 100644 --- a/src/RPC/types.ts +++ b/src/RPC/types.ts @@ -123,40 +123,6 @@ type UnaryHandler = Handler< Promise >; -/** - * @property read Read from the output generator - * @property write Write to the input generator - * @property inputGenerator Low level access to the input generator - * @property outputGenerator Low level access to the output generator - * @property end Signal end to the input generator - * @property close Signal early close to the output generator - * @property throw Throw to both generators - */ -type DuplexCallerInterface = { - read: () => Promise>; - write: (value: I) => Promise; - inputGenerator: AsyncGenerator; - outputGenerator: AsyncGenerator; - end: () => Promise; - close: () => Promise; - throw: (reason: any) => Promise; -}; - -type ServerCallerInterface = { - read: () => Promise>; - outputGenerator: AsyncGenerator; - close: () => Promise; - throw: (reason: any) => Promise; -}; - -type ClientCallerInterface = { - write: (value: I) => Promise; - result: Promise; - inputGenerator: AsyncGenerator; - end: () => Promise; - throw: (reason: any) => Promise; -}; - type StreamPairCreateCallback = () => Promise< ReadableWritablePair >; @@ -174,8 +140,5 @@ export type { ServerStreamHandler, ClientStreamHandler, UnaryHandler, - DuplexCallerInterface, - ServerCallerInterface, - ClientCallerInterface, StreamPairCreateCallback, }; diff --git a/src/RPC/utils.ts b/src/RPC/utils.ts index d8e949fcd..3c6acdcd6 100644 --- a/src/RPC/utils.ts +++ b/src/RPC/utils.ts @@ -425,6 +425,56 @@ function toError(errorData) { }); } +class ClientInputTransformer + implements Transformer +{ + constructor(protected method: string) {} + + transform: TransformerTransformCallback> = async ( + chunk, + controller, + ) => { + const message: JsonRpcRequestMessage = { + method: this.method, + jsonrpc: '2.0', + id: null, + params: chunk, + }; + controller.enqueue(message); + }; +} + +class ClientInputTransformerStream extends TransformStream< + I, + JsonRpcRequestMessage +> { + constructor(method: string) { + super(new ClientInputTransformer(method)); + } +} + +class ClientOutputTransformer + implements Transformer, O> +{ + transform: TransformerTransformCallback, O> = async ( + chunk, + controller, + ) => { + if ('error' in chunk) { + throw toError(chunk.error.data); + } + controller.enqueue(chunk.result); + }; +} + +class ClientOutputTransformerStream< + O extends JSONValue, +> extends TransformStream, O> { + constructor() { + super(new ClientOutputTransformer()); + } +} + export { JsonToJsonMessageStream, JsonMessageToJsonStream, @@ -437,4 +487,6 @@ export { parseJsonRpcMessage, fromError, toError, + ClientInputTransformerStream, + ClientOutputTransformerStream, }; diff --git a/tests/RPC/RPC.test.ts b/tests/RPC/RPC.test.ts index 2a977ee67..48edbb5e5 100644 --- a/tests/RPC/RPC.test.ts +++ b/tests/RPC/RPC.test.ts @@ -48,13 +48,16 @@ describe('RPC', () => { methodName, {}, ); + const writer = callerInterface.writable.getWriter(); + const reader = callerInterface.readable.getReader(); for (const value of values) { - await callerInterface.write(value); - expect((await callerInterface.read()).value).toStrictEqual(value); + await writer.write(value); + expect((await reader.read()).value).toStrictEqual(value); } - await callerInterface.end(); - expect((await callerInterface.read()).value).toBeUndefined(); - expect((await callerInterface.read()).done).toBeTrue(); + await writer.close(); + const result = await reader.read(); + expect(result.value).toBeUndefined(); + expect(result.done).toBeTrue(); await rpcServer.destroy(); await rpcClient.destroy(); }, @@ -93,7 +96,7 @@ describe('RPC', () => { >(methodName, value, {}); const outputs: Array = []; - for await (const num of callerInterface.outputGenerator) { + for await (const num of callerInterface) { outputs.push(num); } expect(outputs.length).toEqual(value); @@ -135,13 +138,14 @@ describe('RPC', () => { number, number >(methodName, {}); + const writer = callerInterface.writable.getWriter(); for (const value of values) { - await callerInterface.write(value); + await writer.write(value); } - await callerInterface.end(); + await writer.close(); const expectedResult = values.reduce((p, c) => p + c); - await expect(callerInterface.result).resolves.toEqual(expectedResult); + await expect(callerInterface.output).resolves.toEqual(expectedResult); await rpcServer.destroy(); await rpcClient.destroy(); }, diff --git a/tests/RPC/RPCClient.test.ts b/tests/RPC/RPCClient.test.ts index 25d2215d1..5e29dc5b7 100644 --- a/tests/RPC/RPCClient.test.ts +++ b/tests/RPC/RPCClient.test.ts @@ -35,14 +35,16 @@ describe(`${RPCClient.name}`, () => { JSONValue, JSONValue >(methodName, { hello: 'world' }); + const reader = callerInterface.readable.getReader(); + const writer = callerInterface.writable.getWriter(); while (true) { - const { value, done } = await callerInterface.read(); + const { value, done } = await reader.read(); if (done) { // We have to end the writer otherwise the stream never closes - await callerInterface.end(); + await writer.close(); break; } - await callerInterface.write(value); + await writer.write(value); } const expectedMessages: Array = messages.map((v) => { const request: JsonRpcRequestMessage = { @@ -78,7 +80,7 @@ describe(`${RPCClient.name}`, () => { JSONValue >(methodName, params as JSONValue, {}); const values: Array = []; - for await (const value of callerInterface.outputGenerator) { + for await (const value of callerInterface) { values.push(value); } const expectedValues = messages.map((v) => v.result); @@ -112,11 +114,12 @@ describe(`${RPCClient.name}`, () => { JSONValue, JSONValue >(methodName, {}); + const writer = callerInterface.writable.getWriter(); for (const param of params) { - await callerInterface.write(param as JSONValue); + await writer.write(param as JSONValue); } - await callerInterface.end(); - expect(await callerInterface.result).toStrictEqual(message.result); + await writer.close(); + expect(await callerInterface.output).toStrictEqual(message.result); const expectedOutput = params.map((v) => JSON.stringify({ method: methodName, @@ -162,7 +165,6 @@ describe(`${RPCClient.name}`, () => { await rpcClient.destroy(); }, ); - testProp( 'generic duplex caller can throw received error message', [ @@ -188,14 +190,14 @@ describe(`${RPCClient.name}`, () => { JSONValue >(methodName, { hello: 'world' }); const consumeToError = async () => { - for await (const _ of callerInterface.outputGenerator) { + for await (const _ of callerInterface.readable) { // No touch, just consume } }; await expect(consumeToError()).rejects.toThrow( rpcErrors.ErrorRpcRemoteError, ); - await callerInterface.end(); + await callerInterface.writable.close(); await outputResult; await rpcClient.destroy(); }, From 3f12ea892445cfe29fec72b75df02772c29f1b4b Mon Sep 17 00:00:00 2001 From: Brian Botha Date: Mon, 23 Jan 2023 14:15:15 +1100 Subject: [PATCH 19/44] feat: implementing withXCaller CO style methods Related #501 [ci skip] --- src/RPC/RPCClient.ts | 58 +++++++++++++++++++ tests/RPC/RPCClient.test.ts | 112 ++++++++++++++++++++++++++++++++++++ 2 files changed, 170 insertions(+) diff --git a/src/RPC/RPCClient.ts b/src/RPC/RPCClient.ts index bc7ea8bfd..d45f6ae38 100644 --- a/src/RPC/RPCClient.ts +++ b/src/RPC/RPCClient.ts @@ -130,6 +130,64 @@ class RPCClient { await writer.close(); return output.value; } + + @ready(new rpcErrors.ErrorRpcDestroyed()) + public async withDuplexCaller( + method: string, + f: (output: AsyncGenerator) => AsyncGenerator, + metadata: POJO, + ): Promise { + const callerInterface = await this.duplexStreamCaller( + method, + metadata, + ); + const outputGenerator = async function* () { + for await (const value of callerInterface.readable) { + yield value; + } + }; + const writer = callerInterface.writable.getWriter(); + for await (const value of f(outputGenerator())) { + await writer.write(value); + } + await writer.close(); + } + + @ready(new rpcErrors.ErrorRpcDestroyed()) + public async withServerCaller( + method: string, + parameters: I, + f: (output: AsyncGenerator) => Promise, + metadata: POJO, + ) { + const callerInterface = await this.serverStreamCaller( + method, + parameters, + metadata, + ); + const outputGenerator = async function* () { + yield* callerInterface; + }; + await f(outputGenerator()); + } + + @ready(new rpcErrors.ErrorRpcDestroyed()) + public async withClientCaller( + method: string, + f: () => AsyncGenerator, + metadata: POJO, + ): Promise { + const callerInterface = await this.clientStreamCaller( + method, + metadata, + ); + const writer = callerInterface.writable.getWriter(); + for await (const value of f()) { + await writer.write(value); + } + await writer.close(); + return callerInterface.output; + } } export default RPCClient; diff --git a/tests/RPC/RPCClient.test.ts b/tests/RPC/RPCClient.test.ts index 5e29dc5b7..91afceb97 100644 --- a/tests/RPC/RPCClient.test.ts +++ b/tests/RPC/RPCClient.test.ts @@ -202,4 +202,116 @@ describe(`${RPCClient.name}`, () => { await rpcClient.destroy(); }, ); + testProp( + 'withDuplexCaller', + [fc.array(rpcTestUtils.jsonRpcResponseResultArb(), { minLength: 1 })], + async (messages) => { + const inputStream = rpcTestUtils.jsonRpcStream(messages); + const [outputResult, outputStream] = rpcTestUtils.streamToArray(); + const streamPair: ReadableWritablePair = { + readable: inputStream, + writable: outputStream, + }; + const rpcClient = await RPCClient.createRPCClient({ + streamPairCreateCallback: async () => streamPair, + logger, + }); + let count = 0; + await rpcClient.withDuplexCaller( + methodName, + async function* (output) { + for await (const value of output) { + count += 1; + yield value; + } + }, + {}, + ); + const result = await outputResult; + // We're just checking that it consuming the messages as expected + expect(result.length).toEqual(messages.length); + expect(count).toEqual(messages.length); + await rpcClient.destroy(); + }, + ); + testProp( + 'withServerCaller', + [ + fc.array(rpcTestUtils.jsonRpcResponseResultArb(), { minLength: 1 }), + rpcTestUtils.safeJsonValueArb, + ], + async (messages, params) => { + const inputStream = rpcTestUtils.jsonRpcStream(messages); + const [outputResult, outputStream] = rpcTestUtils.streamToArray(); + const streamPair: ReadableWritablePair = { + readable: inputStream, + writable: outputStream, + }; + const rpcClient = await RPCClient.createRPCClient({ + streamPairCreateCallback: async () => streamPair, + logger, + }); + let count = 0; + await rpcClient.withServerCaller( + methodName, + params, + async (output) => { + for await (const _ of output) count += 1; + }, + {}, + ); + const result = await outputResult; + expect(count).toEqual(messages.length); + expect(result.toString()).toStrictEqual( + JSON.stringify({ + method: methodName, + jsonrpc: '2.0', + id: null, + params: params, + }), + ); + await rpcClient.destroy(); + }, + ); + testProp( + 'withClientCaller', + [ + rpcTestUtils.jsonRpcResponseResultArb(), + fc.array(rpcTestUtils.safeJsonValueArb, { minLength: 2 }).noShrink(), + ], + async (message, inputMessages) => { + const inputStream = rpcTestUtils.jsonRpcStream([message]); + const [outputResult, outputStream] = rpcTestUtils.streamToArray(); + const streamPair: ReadableWritablePair = { + readable: inputStream, + writable: outputStream, + }; + const rpcClient = await RPCClient.createRPCClient({ + streamPairCreateCallback: async () => streamPair, + logger, + }); + const result = await rpcClient.withClientCaller( + methodName, + async function* () { + for (const inputMessage of inputMessages) { + yield inputMessage; + } + }, + {}, + ); + const expectedResult = inputMessages.map((v) => { + return JSON.stringify({ + method: methodName, + jsonrpc: '2.0', + id: null, + params: v, + }); + }); + expect((await outputResult).map((v) => v.toString())).toStrictEqual( + expectedResult, + ); + expect(result).toStrictEqual(message.result); + await rpcClient.destroy(); + }, + ); }); From e210e48c3013feadf8be8b9cffde8962ebdb7fef Mon Sep 17 00:00:00 2001 From: Brian Botha Date: Mon, 23 Jan 2023 18:21:04 +1100 Subject: [PATCH 20/44] feat: fleshing out error handling Related #500 [ci skip] --- src/RPC/RPCServer.ts | 59 +++++++++++++++++++-------- src/RPC/errors.ts | 6 +++ src/RPC/utils.ts | 24 +++++++++++ tests/RPC/RPCServer.test.ts | 80 +++++++++++++++++++++++++++++++++++++ 4 files changed, 153 insertions(+), 16 deletions(-) diff --git a/src/RPC/RPCServer.ts b/src/RPC/RPCServer.ts index 5d83d3134..b374de488 100644 --- a/src/RPC/RPCServer.ts +++ b/src/RPC/RPCServer.ts @@ -11,6 +11,7 @@ import type { ReadableWritablePair } from 'stream/web'; import type { JSONValue, POJO } from '../types'; import type { ConnectionInfo } from '../network/types'; import type { UnaryHandler } from './types'; +import type { RPCErrorEvent } from './utils'; import { ReadableStream } from 'stream/web'; import { CreateDestroy, ready } from '@matrixai/async-init/dist/CreateDestroy'; import Logger from '@matrixai/logger'; @@ -42,8 +43,8 @@ class RPCServer { protected logger: Logger; protected handlerMap: Map> = new Map(); - private activeStreams: Set> = new Set(); - private events: EventTarget = new EventTarget(); + protected activeStreams: Set> = new Set(); + protected events: EventTarget = new EventTarget(); public constructor({ container, @@ -164,6 +165,7 @@ class RPCServer { const container = this.container; const handlerMap = this.handlerMap; const ctx = { signal: abortController.signal }; + const events = this.events; const outputGen = async function* (): AsyncGenerator { // Step 1, authentication and establishment // read the first message, lets assume the first message is always leading @@ -207,20 +209,29 @@ class RPCServer { yield responseMessage; } } catch (e) { - // This would be an error from the handler or the streams. We should - // catch this and send an error message back through the stream. - const rpcError: JsonRpcError = { - code: e.exitCode, - message: e.description, - data: rpcUtils.fromError(e), - }; - const rpcErrorMessage: JsonRpcResponseError = { - jsonrpc: '2.0', - error: rpcError, - id: null, - }; - // TODO: catch this and emit error in the event emitter - yield rpcErrorMessage; + if (rpcUtils.isReturnableError(e)) { + // We want to convert this error to an error message and pass it along + const rpcError: JsonRpcError = { + code: e.exitCode, + message: e.description, + data: rpcUtils.fromError(e), + }; + const rpcErrorMessage: JsonRpcResponseError = { + jsonrpc: '2.0', + error: rpcError, + id: null, + }; + yield rpcErrorMessage; + } else { + // These errors are emitted to the event system + events.dispatchEvent( + new rpcUtils.RPCErrorEvent({ + detail: { + error: e, + }, + }), + ); + } } resolve(); }; @@ -245,6 +256,22 @@ class RPCServer { .pipeTo(streamPair.writable) .catch(() => {}); } + + public addEventListener( + type: 'error', + callback: (event: RPCErrorEvent) => void, + options?: boolean | AddEventListenerOptions | undefined, + ) { + this.events.addEventListener(type, callback, options); + } + + public removeEventListener( + type: 'error', + callback: (event: RPCErrorEvent) => void, + options?: boolean | AddEventListenerOptions | undefined, + ) { + this.events.removeEventListener(type, callback, options); + } } export default RPCServer; diff --git a/src/RPC/errors.ts b/src/RPC/errors.ts index e47722205..a5ba14ae9 100644 --- a/src/RPC/errors.ts +++ b/src/RPC/errors.ts @@ -47,6 +47,11 @@ class ErrorRpcRemoteError extends ErrorRpc { exitCode = sysexits.UNAVAILABLE; } +class ErrorRpcPlaceholderConnectionError extends ErrorRpc { + static description = 'placeholder error for connection stream failure'; + exitCode = sysexits.UNAVAILABLE; +} + export { ErrorRpc, ErrorRpcRunning, @@ -58,4 +63,5 @@ export { ErrorRpcProtocal, ErrorRpcMessageLength, ErrorRpcRemoteError, + ErrorRpcPlaceholderConnectionError, }; diff --git a/src/RPC/utils.ts b/src/RPC/utils.ts index 3c6acdcd6..8e5ef2df1 100644 --- a/src/RPC/utils.ts +++ b/src/RPC/utils.ts @@ -475,6 +475,28 @@ class ClientOutputTransformerStream< } } +function isReturnableError(e: Error): boolean { + if (e instanceof rpcErrors.ErrorRpcPlaceholderConnectionError) return false; + return true; +} + +class RPCErrorEvent extends Event { + public detail: { + error: any; + }; + + constructor( + options: EventInit & { + detail: { + error: any; + }; + }, + ) { + super('error', options); + this.detail = options.detail; + } +} + export { JsonToJsonMessageStream, JsonMessageToJsonStream, @@ -489,4 +511,6 @@ export { toError, ClientInputTransformerStream, ClientOutputTransformerStream, + isReturnableError, + RPCErrorEvent, }; diff --git a/tests/RPC/RPCServer.test.ts b/tests/RPC/RPCServer.test.ts index 86e224afa..4481d5cc6 100644 --- a/tests/RPC/RPCServer.test.ts +++ b/tests/RPC/RPCServer.test.ts @@ -12,6 +12,7 @@ import type { ReadableWritablePair } from 'stream/web'; import { testProp, fc } from '@fast-check/jest'; import Logger, { LogLevel, StreamHandler } from '@matrixai/logger'; import RPCServer from '@/RPC/RPCServer'; +import * as rpcErrors from '@/RPC/errors'; import * as rpcTestUtils from './utils'; describe(`${RPCServer.name}`, () => { @@ -289,6 +290,85 @@ describe(`${RPCServer.name}`, () => { await rpcServer.destroy(); }); + const errorArb = fc.oneof( + fc.constant(new rpcErrors.ErrorRpcParse()), + fc.constant(new rpcErrors.ErrorRpcHandlerMissing()), + fc.constant(new rpcErrors.ErrorRpcProtocal()), + fc.constant(new rpcErrors.ErrorRpcMessageLength()), + fc.constant(new rpcErrors.ErrorRpcRemoteError()), + ); + testProp( + 'should send error message', + [specificMessageArb, errorArb], + async (messages, error) => { + const stream = rpcTestUtils.jsonRpcStream(messages); + const container = {}; + const rpcServer = await RPCServer.createRPCServer({ container, logger }); + let resolve, reject; + const errorProm = new Promise((resolve_, reject_) => { + resolve = resolve_; + reject = reject_; + }); + rpcServer.addEventListener('error', (thing) => { + resolve(thing); + }); + const [outputResult, outputStream] = rpcTestUtils.streamToArray(); + const readWriteStream: ReadableWritablePair = { + readable: stream, + writable: outputStream, + }; + + const duplexHandler: DuplexStreamHandler = + async function* (_input, _container, _connectionInfo, _ctx) { + throw error; + }; + + rpcServer.registerDuplexStreamHandler(methodName, duplexHandler); + rpcServer.handleStream(readWriteStream, {} as ConnectionInfo); + const errorMessage = JSON.parse((await outputResult)[0]!.toString()); + expect(errorMessage.error.code).toEqual(error.exitCode); + expect(errorMessage.error.message).toEqual(error.description); + reject(); + await expect(errorProm).toReject(); + await rpcServer.destroy(); + }, + ); + testProp( + 'should emit stream error', + [specificMessageArb], + async (messages) => { + const stream = rpcTestUtils.jsonRpcStream(messages); + const container = {}; + const rpcServer = await RPCServer.createRPCServer({ container, logger }); + let resolve, reject; + const errorProm = new Promise((resolve_, reject_) => { + resolve = resolve_; + reject = reject_; + }); + rpcServer.addEventListener('error', (thing) => { + resolve(thing); + }); + const [outputResult, outputStream] = rpcTestUtils.streamToArray(); + const readWriteStream: ReadableWritablePair = { + readable: stream, + writable: outputStream, + }; + + const duplexHandler: DuplexStreamHandler = + async function* (_input, _container, _connectionInfo, _ctx) { + throw new rpcErrors.ErrorRpcPlaceholderConnectionError(); + }; + + rpcServer.registerDuplexStreamHandler(methodName, duplexHandler); + rpcServer.handleStream(readWriteStream, {} as ConnectionInfo); + await outputResult; + + await rpcServer.destroy(); + reject(); + await expect(errorProm).toResolve(); + }, + ); + // TODO: // - Test odd conditions for handlers, like extra messages where 1 is expected. // - Expectations can't be inside the handlers otherwise they're caught. From 260f23b9a35c2ac67430dfd58247719b36d207c2 Mon Sep 17 00:00:00 2001 From: Brian Botha Date: Mon, 23 Jan 2023 18:48:36 +1100 Subject: [PATCH 21/44] feat: updated streams to use `Uint8Array` instead of `Buffer` Mostly a type change, `Buffer` just extended `Uint8Array`. Related #500 [ci skip] --- src/RPC/RPCServer.ts | 2 +- src/RPC/types.ts | 2 +- src/RPC/utils.ts | 15 +++++++++------ tests/RPC/RPC.test.ts | 16 ++++++++-------- tests/RPC/RPCClient.test.ts | 18 ++++++++++++------ tests/RPC/RPCServer.test.ts | 2 +- tests/RPC/utils.ts | 18 +++++++++--------- 7 files changed, 41 insertions(+), 32 deletions(-) diff --git a/src/RPC/RPCServer.ts b/src/RPC/RPCServer.ts index b374de488..23e071a0f 100644 --- a/src/RPC/RPCServer.ts +++ b/src/RPC/RPCServer.ts @@ -133,7 +133,7 @@ class RPCServer { @ready(new rpcErrors.ErrorRpcDestroyed()) public handleStream( - streamPair: ReadableWritablePair, + streamPair: ReadableWritablePair, connectionInfo: ConnectionInfo, ) { // This will take a buffer stream of json messages and set up service diff --git a/src/RPC/types.ts b/src/RPC/types.ts index 884709b68..6f5797a0a 100644 --- a/src/RPC/types.ts +++ b/src/RPC/types.ts @@ -124,7 +124,7 @@ type UnaryHandler = Handler< >; type StreamPairCreateCallback = () => Promise< - ReadableWritablePair + ReadableWritablePair >; export type { diff --git a/src/RPC/utils.ts b/src/RPC/utils.ts index 8e5ef2df1..d3f50ce14 100644 --- a/src/RPC/utils.ts +++ b/src/RPC/utils.ts @@ -23,7 +23,7 @@ import * as errors from '../errors'; const jsonStreamParsers = require('@streamparser/json'); class JsonToJsonMessage - implements Transformer + implements Transformer { protected bytesWritten: number = 0; @@ -45,7 +45,7 @@ class JsonToJsonMessage }; }; - transform: TransformerTransformCallback = async (chunk) => { + transform: TransformerTransformCallback = async (chunk) => { try { this.bytesWritten += chunk.byteLength; this.parser.write(chunk); @@ -60,7 +60,7 @@ class JsonToJsonMessage // TODO: rename to something more descriptive? class JsonToJsonMessageStream extends TransformStream< - Buffer, + Uint8Array, T > { constructor( @@ -71,8 +71,8 @@ class JsonToJsonMessageStream extends TransformStream< } } -class JsonMessageToJson implements Transformer { - transform: TransformerTransformCallback = async ( +class JsonMessageToJson implements Transformer { + transform: TransformerTransformCallback = async ( chunk, controller, ) => { @@ -81,7 +81,10 @@ class JsonMessageToJson implements Transformer { } // TODO: rename to something more descriptive? -class JsonMessageToJsonStream extends TransformStream { +class JsonMessageToJsonStream extends TransformStream< + JsonRpcMessage, + Uint8Array +> { constructor() { super(new JsonMessageToJson()); } diff --git a/tests/RPC/RPC.test.ts b/tests/RPC/RPC.test.ts index 48edbb5e5..1f5220d55 100644 --- a/tests/RPC/RPC.test.ts +++ b/tests/RPC/RPC.test.ts @@ -22,8 +22,8 @@ describe('RPC', () => { [fc.array(rpcTestUtils.safeJsonValueArb, { minLength: 1 })], async (values) => { const { clientPair, serverPair } = rpcTestUtils.createTapPairs< - Buffer, - Buffer + Uint8Array, + Uint8Array >(); const container = {}; @@ -68,8 +68,8 @@ describe('RPC', () => { [fc.integer({ min: 1, max: 100 })], async (value) => { const { clientPair, serverPair } = rpcTestUtils.createTapPairs< - Buffer, - Buffer + Uint8Array, + Uint8Array >(); const container = {}; @@ -110,8 +110,8 @@ describe('RPC', () => { [fc.array(fc.integer(), { minLength: 1 }).noShrink()], async (values) => { const { clientPair, serverPair } = rpcTestUtils.createTapPairs< - Buffer, - Buffer + Uint8Array, + Uint8Array >(); const container = {}; @@ -156,8 +156,8 @@ describe('RPC', () => { [rpcTestUtils.safeJsonValueArb], async (value) => { const { clientPair, serverPair } = rpcTestUtils.createTapPairs< - Buffer, - Buffer + Uint8Array, + Uint8Array >(); const container = {}; diff --git a/tests/RPC/RPCClient.test.ts b/tests/RPC/RPCClient.test.ts index 91afceb97..646972370 100644 --- a/tests/RPC/RPCClient.test.ts +++ b/tests/RPC/RPCClient.test.ts @@ -22,7 +22,8 @@ describe(`${RPCClient.name}`, () => { testProp('generic duplex caller', [specificMessageArb], async (messages) => { const inputStream = rpcTestUtils.jsonRpcStream(messages); - const [outputResult, outputStream] = rpcTestUtils.streamToArray(); + const [outputResult, outputStream] = + rpcTestUtils.streamToArray(); const streamPair: ReadableWritablePair = { readable: inputStream, writable: outputStream, @@ -101,7 +102,8 @@ describe(`${RPCClient.name}`, () => { [rpcTestUtils.jsonRpcResponseResultArb(), fc.array(fc.jsonValue())], async (message, params) => { const inputStream = rpcTestUtils.jsonRpcStream([message]); - const [outputResult, outputStream] = rpcTestUtils.streamToArray(); + const [outputResult, outputStream] = + rpcTestUtils.streamToArray(); const streamPair: ReadableWritablePair = { readable: inputStream, writable: outputStream, @@ -176,7 +178,8 @@ describe(`${RPCClient.name}`, () => { ...messages, errorMessage, ]); - const [outputResult, outputStream] = rpcTestUtils.streamToArray(); + const [outputResult, outputStream] = + rpcTestUtils.streamToArray(); const streamPair: ReadableWritablePair = { readable: inputStream, writable: outputStream, @@ -207,7 +210,8 @@ describe(`${RPCClient.name}`, () => { [fc.array(rpcTestUtils.jsonRpcResponseResultArb(), { minLength: 1 })], async (messages) => { const inputStream = rpcTestUtils.jsonRpcStream(messages); - const [outputResult, outputStream] = rpcTestUtils.streamToArray(); + const [outputResult, outputStream] = + rpcTestUtils.streamToArray(); const streamPair: ReadableWritablePair = { readable: inputStream, writable: outputStream, @@ -242,7 +246,8 @@ describe(`${RPCClient.name}`, () => { ], async (messages, params) => { const inputStream = rpcTestUtils.jsonRpcStream(messages); - const [outputResult, outputStream] = rpcTestUtils.streamToArray(); + const [outputResult, outputStream] = + rpcTestUtils.streamToArray(); const streamPair: ReadableWritablePair = { readable: inputStream, writable: outputStream, @@ -281,7 +286,8 @@ describe(`${RPCClient.name}`, () => { ], async (message, inputMessages) => { const inputStream = rpcTestUtils.jsonRpcStream([message]); - const [outputResult, outputStream] = rpcTestUtils.streamToArray(); + const [outputResult, outputStream] = + rpcTestUtils.streamToArray(); const streamPair: ReadableWritablePair = { readable: inputStream, writable: outputStream, diff --git a/tests/RPC/RPCServer.test.ts b/tests/RPC/RPCServer.test.ts index 4481d5cc6..e1b0685eb 100644 --- a/tests/RPC/RPCServer.test.ts +++ b/tests/RPC/RPCServer.test.ts @@ -231,7 +231,7 @@ describe(`${RPCServer.name}`, () => { let thing; let lastMessage: JsonRpcMessage | undefined; const tapStream: any = {}; - // Const tapStream = new rpcTestUtils.TapStream( + // Const tapStream = new rpcTestUtils.TapStream( // async (_, iteration) => { // if (iteration === 2) { // // @ts-ignore: kidnap private property diff --git a/tests/RPC/utils.ts b/tests/RPC/utils.ts index b11ae020e..e77b4e9a6 100644 --- a/tests/RPC/utils.ts +++ b/tests/RPC/utils.ts @@ -21,7 +21,7 @@ import { fc } from '@fast-check/jest'; import * as utils from '@/utils'; import { fromError } from '@/RPC/utils'; -class BufferStreamToSnipped implements Transformer { +class BufferStreamToSnipped implements Transformer { protected buffer = Buffer.alloc(0); protected iteration = 0; protected snippingPattern: Array; @@ -30,7 +30,7 @@ class BufferStreamToSnipped implements Transformer { this.snippingPattern = snippingPattern; } - transform: TransformerTransformCallback = async ( + transform: TransformerTransformCallback = async ( chunk, controller, ) => { @@ -46,7 +46,7 @@ class BufferStreamToSnipped implements Transformer { } }; - flush: TransformerFlushCallback = (controller) => { + flush: TransformerFlushCallback = (controller) => { controller.enqueue(this.buffer); }; } @@ -62,15 +62,15 @@ class BufferStreamToSnippedStream extends TransformStream { } } -class BufferStreamToNoisy implements Transformer { +class BufferStreamToNoisy implements Transformer { protected iteration = 0; - protected noise: Array; + protected noise: Array; - constructor(noise: Array) { + constructor(noise: Array) { this.noise = noise; } - transform: TransformerTransformCallback = async ( + transform: TransformerTransformCallback = async ( chunk, controller, ) => { @@ -87,13 +87,13 @@ class BufferStreamToNoisy implements Transformer { * splitting up the data. */ class BufferStreamToNoisyStream extends TransformStream { - constructor(noise: Array) { + constructor(noise: Array) { super(new BufferStreamToNoisy(noise)); } } const jsonRpcStream = (messages: Array) => { - return new ReadableStream({ + return new ReadableStream({ async start(controller) { for (const arrayElement of messages) { // Controller.enqueue(arrayElement) From 74450a91ec06b780164dae9b2a22bcb4ff86d222 Mon Sep 17 00:00:00 2001 From: Brian Botha Date: Tue, 24 Jan 2023 19:36:06 +1100 Subject: [PATCH 22/44] feat: middleware Related #502 Related #500 Related #502 [ci skip] --- src/RPC/RPCClient.ts | 61 +++++++++++-- src/RPC/RPCServer.ts | 177 ++++++++++++++++++++++++------------ src/RPC/errors.ts | 7 +- src/RPC/types.ts | 11 +++ src/RPC/utils.ts | 94 ++++++++++++++++--- tests/RPC/RPCClient.test.ts | 125 ++++++++++++++++++++++++- tests/RPC/RPCServer.test.ts | 174 ++++++++++++++++++++++++++++++++++- tests/RPC/utils.test.ts | 31 +++++++ 8 files changed, 599 insertions(+), 81 deletions(-) diff --git a/src/RPC/RPCClient.ts b/src/RPC/RPCClient.ts index d45f6ae38..41109f76a 100644 --- a/src/RPC/RPCClient.ts +++ b/src/RPC/RPCClient.ts @@ -1,6 +1,12 @@ import type { StreamPairCreateCallback } from './types'; import type { JSONValue, POJO } from 'types'; import type { ReadableWritablePair } from 'stream/web'; +import type { + JsonRpcRequest, + JsonRpcResponse, + MiddlewareFactory, + Middleware, +} from './types'; import { CreateDestroy, ready } from '@matrixai/async-init/dist/CreateDestroy'; import Logger from '@matrixai/logger'; import * as rpcErrors from './errors'; @@ -50,14 +56,24 @@ class RPCClient { _metadata: POJO, ): Promise> { const streamPair = await this.streamPairCreateCallback(); - const outputStream = streamPair.readable - .pipeThrough( - new rpcUtils.JsonToJsonMessageStream(rpcUtils.parseJsonRpcResponse), - ) - .pipeThrough(new rpcUtils.ClientOutputTransformerStream()); + let reverseMiddlewareStream = streamPair.readable.pipeThrough( + new rpcUtils.JsonToJsonMessageStream(rpcUtils.parseJsonRpcResponse), + ); + for (const middleWare of this.reverseMiddleware) { + const middle = middleWare(); + reverseMiddlewareStream = middle(reverseMiddlewareStream); + } + const outputStream = reverseMiddlewareStream.pipeThrough( + new rpcUtils.ClientOutputTransformerStream(), + ); const inputMessageTransformer = new rpcUtils.ClientInputTransformerStream(method); - void inputMessageTransformer.readable + let forwardMiddlewareStream = inputMessageTransformer.readable; + for (const middleware of this.forwardMiddleWare) { + const middle = middleware(); + forwardMiddlewareStream = middle(forwardMiddlewareStream); + } + void forwardMiddlewareStream .pipeThrough(new rpcUtils.JsonMessageToJsonStream()) .pipeTo(streamPair.writable) .catch(() => {}); @@ -188,6 +204,39 @@ class RPCClient { await writer.close(); return callerInterface.output; } + + protected forwardMiddleWare: Array< + MiddlewareFactory>> + > = []; + protected reverseMiddleware: Array< + MiddlewareFactory>> + > = []; + + @ready(new rpcErrors.ErrorRpcDestroyed()) + public registerForwardMiddleware( + middlewareFactory: MiddlewareFactory>>, + ) { + this.forwardMiddleWare.push(middlewareFactory); + } + + @ready(new rpcErrors.ErrorRpcDestroyed()) + public clearForwardMiddleware() { + this.reverseMiddleware = []; + } + + @ready(new rpcErrors.ErrorRpcDestroyed()) + public registerReverseMiddleware( + middlewareFactory: MiddlewareFactory< + Middleware> + >, + ) { + this.reverseMiddleware.push(middlewareFactory); + } + + @ready(new rpcErrors.ErrorRpcDestroyed()) + public clearReverseMiddleware() { + this.reverseMiddleware = []; + } } export default RPCClient; diff --git a/src/RPC/RPCServer.ts b/src/RPC/RPCServer.ts index 23e071a0f..7056a823a 100644 --- a/src/RPC/RPCServer.ts +++ b/src/RPC/RPCServer.ts @@ -1,23 +1,29 @@ import type { - ServerStreamHandler, + ClientStreamHandler, DuplexStreamHandler, JsonRpcError, - JsonRpcMessage, + JsonRpcRequest, + JsonRpcResponse, JsonRpcResponseError, JsonRpcResponseResult, - ClientStreamHandler, + ServerStreamHandler, + UnaryHandler, } from './types'; import type { ReadableWritablePair } from 'stream/web'; import type { JSONValue, POJO } from '../types'; import type { ConnectionInfo } from '../network/types'; -import type { UnaryHandler } from './types'; import type { RPCErrorEvent } from './utils'; +import type { + MiddlewareFactory, + MiddlewareShort, + Middleware, +} from 'tokens/types'; import { ReadableStream } from 'stream/web'; import { CreateDestroy, ready } from '@matrixai/async-init/dist/CreateDestroy'; import Logger from '@matrixai/logger'; import { PromiseCancellable } from '@matrixai/async-cancellable'; -import * as rpcErrors from './errors'; import * as rpcUtils from './utils'; +import * as rpcErrors from './errors'; interface RPCServer extends CreateDestroy {} @CreateDestroy() @@ -152,13 +158,22 @@ class RPCServer { void handlerProm .finally(() => this.activeStreams.delete(handlerProm)) .catch(() => {}); + // Setting up forward middleware + let middlewareStream = streamPair.readable.pipeThrough( + new rpcUtils.JsonToJsonMessageStream(rpcUtils.parseJsonRpcRequest), + ); + const shortMessageQueue: Array = []; + for (const forwardMiddleWareFactory of this.forwardMiddleWare) { + const middleware = forwardMiddleWareFactory(); + middlewareStream = middleware( + middlewareStream, + (value: JsonRpcResponse) => shortMessageQueue.push(value), + ); + } // While ReadableStream can be converted to AsyncIterable, we want it as // a generator. const inputGen = async function* () { - const pojoStream = streamPair.readable.pipeThrough( - new rpcUtils.JsonToJsonMessageStream(rpcUtils.parseJsonRpcRequest), - ); - for await (const dataMessage of pojoStream) { + for await (const dataMessage of middlewareStream) { yield dataMessage; } }; @@ -166,7 +181,9 @@ class RPCServer { const handlerMap = this.handlerMap; const ctx = { signal: abortController.signal }; const events = this.events; - const outputGen = async function* (): AsyncGenerator { + const outputGen = async function* (): AsyncGenerator< + JsonRpcResponse + > { // Step 1, authentication and establishment // read the first message, lets assume the first message is always leading // metadata. @@ -184,7 +201,6 @@ class RPCServer { yield data.params as JSONValue; } }; - // TODO: validation on metadata const handler = handlerMap.get(method); if (handler == null) { // Failed to find handler, this is an error. We should respond with @@ -194,69 +210,80 @@ class RPCServer { ); } if (ctx.signal.aborted) throw ctx.signal.reason; - try { - for await (const response of handler( - dataGen(), - container, - connectionInfo, - ctx, - )) { - const responseMessage: JsonRpcResponseResult = { - jsonrpc: '2.0', - result: response, - id: null, - }; - yield responseMessage; - } - } catch (e) { - if (rpcUtils.isReturnableError(e)) { - // We want to convert this error to an error message and pass it along - const rpcError: JsonRpcError = { - code: e.exitCode, - message: e.description, - data: rpcUtils.fromError(e), - }; - const rpcErrorMessage: JsonRpcResponseError = { - jsonrpc: '2.0', - error: rpcError, - id: null, - }; - yield rpcErrorMessage; - } else { - // These errors are emitted to the event system - events.dispatchEvent( - new rpcUtils.RPCErrorEvent({ - detail: { - error: e, - }, - }), - ); - } + for await (const response of handler( + dataGen(), + container, + connectionInfo, + ctx, + )) { + const responseMessage: JsonRpcResponseResult = { + jsonrpc: '2.0', + result: response, + id: null, + }; + yield responseMessage; } - resolve(); }; const outputGenerator = outputGen(); - const outputStream = new ReadableStream({ + let reverseMiddlewareStream = new ReadableStream< + JsonRpcResponse + >({ pull: async (controller) => { - const { value, done } = await outputGenerator.next(); - if (done) { + try { + const { value, done } = await outputGenerator.next(); + if (done) { + controller.close(); + resolve(); + return; + } + controller.enqueue(value); + } catch (e) { + if (rpcUtils.isReturnableError(e)) { + // We want to convert this error to an error message and pass it along + const rpcError: JsonRpcError = { + code: e.exitCode, + message: e.description, + data: rpcUtils.fromError(e), + }; + const rpcErrorMessage: JsonRpcResponseError = { + jsonrpc: '2.0', + error: rpcError, + id: null, + }; + controller.enqueue(rpcErrorMessage); + } else { + // These errors are emitted to the event system + events.dispatchEvent( + new rpcUtils.RPCErrorEvent({ + detail: { + error: e, + }, + }), + ); + } controller.close(); - return; + resolve(); } - controller.enqueue(value); }, cancel: async (reason) => { await outputGenerator.throw(reason); }, }); - void outputStream + // Setting up reverse middleware + for (const reverseMiddleWareFactory of this.reverseMiddleware) { + const middleware = reverseMiddleWareFactory(); + reverseMiddlewareStream = middleware(reverseMiddlewareStream); + } + reverseMiddlewareStream + .pipeThrough(new rpcUtils.QueueMergingTransformStream(shortMessageQueue)) .pipeThrough(new rpcUtils.JsonMessageToJsonStream()) .pipeTo(streamPair.writable) .catch(() => {}); } + @ready(new rpcErrors.ErrorRpcDestroyed()) public addEventListener( type: 'error', callback: (event: RPCErrorEvent) => void, @@ -265,6 +292,7 @@ class RPCServer { this.events.addEventListener(type, callback, options); } + @ready(new rpcErrors.ErrorRpcDestroyed()) public removeEventListener( type: 'error', callback: (event: RPCErrorEvent) => void, @@ -272,6 +300,43 @@ class RPCServer { ) { this.events.removeEventListener(type, callback, options); } + + protected forwardMiddleWare: Array< + MiddlewareFactory< + MiddlewareShort, JsonRpcResponse> + > + > = []; + protected reverseMiddleware: Array< + MiddlewareFactory>> + > = []; + + @ready(new rpcErrors.ErrorRpcDestroyed()) + public registerForwardMiddleware( + middlewareFactory: MiddlewareFactory< + MiddlewareShort, JsonRpcResponse> + >, + ) { + this.forwardMiddleWare.push(middlewareFactory); + } + + @ready(new rpcErrors.ErrorRpcDestroyed()) + public clearForwardMiddleware() { + this.reverseMiddleware = []; + } + + @ready(new rpcErrors.ErrorRpcDestroyed()) + public registerReverseMiddleware( + middlewareFactory: MiddlewareFactory< + Middleware> + >, + ) { + this.reverseMiddleware.push(middlewareFactory); + } + + @ready(new rpcErrors.ErrorRpcDestroyed()) + public clearReverseMiddleware() { + this.reverseMiddleware = []; + } } export default RPCServer; diff --git a/src/RPC/errors.ts b/src/RPC/errors.ts index a5ba14ae9..d0e2b8e7e 100644 --- a/src/RPC/errors.ts +++ b/src/RPC/errors.ts @@ -47,7 +47,11 @@ class ErrorRpcRemoteError extends ErrorRpc { exitCode = sysexits.UNAVAILABLE; } -class ErrorRpcPlaceholderConnectionError extends ErrorRpc { +class ErrorRpcNoMessageError extends ErrorRpc { + static description = 'For errors not to be conveyed to the client'; +} + +class ErrorRpcPlaceholderConnectionError extends ErrorRpcNoMessageError { static description = 'placeholder error for connection stream failure'; exitCode = sysexits.UNAVAILABLE; } @@ -63,5 +67,6 @@ export { ErrorRpcProtocal, ErrorRpcMessageLength, ErrorRpcRemoteError, + ErrorRpcNoMessageError, ErrorRpcPlaceholderConnectionError, }; diff --git a/src/RPC/types.ts b/src/RPC/types.ts index 6f5797a0a..a233da2df 100644 --- a/src/RPC/types.ts +++ b/src/RPC/types.ts @@ -2,6 +2,7 @@ import type { JSONValue, POJO } from '../types'; import type { ConnectionInfo } from '../network/types'; import type { ContextCancellable } from '../contexts/types'; import type { ReadableWritablePair } from 'stream/web'; +import type { ReadableStream } from 'stream/web'; /** * This is the JSON RPC request object. this is the generic message type used for the RPC. @@ -127,6 +128,13 @@ type StreamPairCreateCallback = () => Promise< ReadableWritablePair >; +type MiddlewareShort = ( + input: ReadableStream, + short: (value: K) => void, +) => ReadableStream; +type Middleware = (input: ReadableStream) => ReadableStream; +type MiddlewareFactory = () => T; + export type { JsonRpcRequestMessage, JsonRpcRequestNotification, @@ -141,4 +149,7 @@ export type { ClientStreamHandler, UnaryHandler, StreamPairCreateCallback, + MiddlewareShort, + Middleware, + MiddlewareFactory, }; diff --git a/src/RPC/utils.ts b/src/RPC/utils.ts index d3f50ce14..9c87232c7 100644 --- a/src/RPC/utils.ts +++ b/src/RPC/utils.ts @@ -2,6 +2,7 @@ import type { Transformer, TransformerTransformCallback, TransformerStartCallback, + TransformerFlushCallback, } from 'stream/web'; import type { JsonRpcError, @@ -14,12 +15,14 @@ import type { JsonRpcResponse, } from 'RPC/types'; import type { JSONValue } from '../types'; +import type { JsonValue } from 'fast-check'; import { TransformStream } from 'stream/web'; import { AbstractError } from '@matrixai/errors'; import * as rpcErrors from './errors'; import * as utils from '../utils'; import * as validationErrors from '../validation/errors'; import * as errors from '../errors'; +import { promise } from '../utils'; const jsonStreamParsers = require('@streamparser/json'); class JsonToJsonMessage @@ -429,27 +432,25 @@ function toError(errorData) { } class ClientInputTransformer - implements Transformer + implements Transformer> { constructor(protected method: string) {} - transform: TransformerTransformCallback> = async ( - chunk, - controller, - ) => { - const message: JsonRpcRequestMessage = { - method: this.method, - jsonrpc: '2.0', - id: null, - params: chunk, + transform: TransformerTransformCallback> = + async (chunk, controller) => { + const message: JsonRpcRequest = { + method: this.method, + jsonrpc: '2.0', + id: null, + params: chunk, + }; + controller.enqueue(message); }; - controller.enqueue(message); - }; } class ClientInputTransformerStream extends TransformStream< I, - JsonRpcRequestMessage + JsonRpcRequest > { constructor(method: string) { super(new ClientInputTransformer(method)); @@ -479,7 +480,7 @@ class ClientOutputTransformerStream< } function isReturnableError(e: Error): boolean { - if (e instanceof rpcErrors.ErrorRpcPlaceholderConnectionError) return false; + if (e instanceof rpcErrors.ErrorRpcNoMessageError) return false; return true; } @@ -500,6 +501,69 @@ class RPCErrorEvent extends Event { } } +const controllerTransformationFactory = () => { + const controllerProm = promise>(); + + class ControllerTransform implements Transformer { + start: TransformerStartCallback = async (controller) => { + // @ts-ignore: type mismatch oddity + controllerProm.resolveP(controller); + }; + + transform: TransformerTransformCallback = async ( + chunk, + controller, + ) => { + controller.enqueue(chunk); + }; + } + + class ControllerTransformStream extends TransformStream { + constructor() { + super(new ControllerTransform()); + } + } + return { + controllerP: controllerProm.p, + controllerTransformStream: new ControllerTransformStream(), + }; +}; + +class QueueMergingTransform implements Transformer { + constructor(protected messageQueue: Array) {} + + start: TransformerStartCallback = async (controller) => { + while (true) { + const value = this.messageQueue.shift(); + if (value == null) break; + controller.enqueue(value); + } + }; + + transform: TransformerTransformCallback = async (chunk, controller) => { + while (true) { + const value = this.messageQueue.shift(); + if (value == null) break; + controller.enqueue(value); + } + controller.enqueue(chunk); + }; + + flush: TransformerFlushCallback = (controller) => { + while (true) { + const value = this.messageQueue.shift(); + if (value == null) break; + controller.enqueue(value); + } + }; +} + +class QueueMergingTransformStream extends TransformStream { + constructor(messageQueue: Array) { + super(new QueueMergingTransform(messageQueue)); + } +} + export { JsonToJsonMessageStream, JsonMessageToJsonStream, @@ -516,4 +580,6 @@ export { ClientOutputTransformerStream, isReturnableError, RPCErrorEvent, + controllerTransformationFactory, + QueueMergingTransformStream, }; diff --git a/tests/RPC/RPCClient.test.ts b/tests/RPC/RPCClient.test.ts index 646972370..9ce0d289e 100644 --- a/tests/RPC/RPCClient.test.ts +++ b/tests/RPC/RPCClient.test.ts @@ -1,6 +1,11 @@ import type { ReadableWritablePair } from 'stream/web'; import type { JSONValue } from '@/types'; -import type { JsonRpcRequestMessage } from '@/RPC/types'; +import type { + JsonRpcRequest, + JsonRpcRequestMessage, + JsonRpcResponse, +} from '@/RPC/types'; +import { TransformStream } from 'stream/web'; import Logger, { LogLevel, StreamHandler } from '@matrixai/logger'; import { testProp, fc } from '@fast-check/jest'; import RPCClient from '@/RPC/RPCClient'; @@ -320,4 +325,122 @@ describe(`${RPCClient.name}`, () => { await rpcClient.destroy(); }, ); + testProp( + 'generic duplex caller with forward Middleware', + [specificMessageArb], + async (messages) => { + const inputStream = rpcTestUtils.jsonRpcStream(messages); + const [outputResult, outputStream] = + rpcTestUtils.streamToArray(); + const streamPair: ReadableWritablePair = { + readable: inputStream, + writable: outputStream, + }; + const rpcClient = await RPCClient.createRPCClient({ + streamPairCreateCallback: async () => streamPair, + logger, + }); + + rpcClient.registerForwardMiddleware(() => { + return (input) => + input.pipeThrough( + new TransformStream< + JsonRpcRequest, + JsonRpcRequest + >({ + transform: (chunk, controller) => { + controller.enqueue({ + ...chunk, + params: 'one', + }); + }, + }), + ); + }); + const callerInterface = await rpcClient.duplexStreamCaller< + JSONValue, + JSONValue + >(methodName, { hello: 'world' }); + const reader = callerInterface.readable.getReader(); + const writer = callerInterface.writable.getWriter(); + while (true) { + const { value, done } = await reader.read(); + if (done) { + // We have to end the writer otherwise the stream never closes + await writer.close(); + break; + } + await writer.write(value); + } + + const expectedMessages: Array = messages.map( + () => { + const request: JsonRpcRequestMessage = { + jsonrpc: '2.0', + method: methodName, + id: null, + params: 'one', + }; + return request; + }, + ); + const outputMessages = (await outputResult).map((v) => + JSON.parse(v.toString()), + ); + expect(outputMessages).toStrictEqual(expectedMessages); + await rpcClient.destroy(); + }, + ); + testProp.only( + 'generic duplex caller with reverse Middleware', + [specificMessageArb], + async (messages) => { + const inputStream = rpcTestUtils.jsonRpcStream(messages); + const [outputResult, outputStream] = + rpcTestUtils.streamToArray(); + const streamPair: ReadableWritablePair = { + readable: inputStream, + writable: outputStream, + }; + const rpcClient = await RPCClient.createRPCClient({ + streamPairCreateCallback: async () => streamPair, + logger, + }); + + rpcClient.registerReverseMiddleware(() => { + return (input) => + input.pipeThrough( + new TransformStream< + JsonRpcResponse, + JsonRpcResponse + >({ + transform: (chunk, controller) => { + controller.enqueue({ + ...chunk, + result: 'one', + }); + }, + }), + ); + }); + const callerInterface = await rpcClient.duplexStreamCaller< + JSONValue, + JSONValue + >(methodName, { hello: 'world' }); + const reader = callerInterface.readable.getReader(); + const writer = callerInterface.writable.getWriter(); + while (true) { + const { value, done } = await reader.read(); + if (done) { + // We have to end the writer otherwise the stream never closes + await writer.close(); + break; + } + expect(value).toBe('one'); + await writer.write(value); + } + await outputResult; + await rpcClient.destroy(); + }, + ); }); diff --git a/tests/RPC/RPCServer.test.ts b/tests/RPC/RPCServer.test.ts index e1b0685eb..40b9881d8 100644 --- a/tests/RPC/RPCServer.test.ts +++ b/tests/RPC/RPCServer.test.ts @@ -2,6 +2,8 @@ import type { ClientStreamHandler, DuplexStreamHandler, JsonRpcMessage, + JsonRpcRequest, + JsonRpcResponse, ServerStreamHandler, UnaryHandler, } from '@/RPC/types'; @@ -9,7 +11,8 @@ import type { JSONValue } from '@/types'; import type { ConnectionInfo, Host, Port } from '@/network/types'; import type { NodeId } from '@/ids'; import type { ReadableWritablePair } from 'stream/web'; -import { testProp, fc } from '@fast-check/jest'; +import { TransformStream } from 'stream/web'; +import { fc, testProp } from '@fast-check/jest'; import Logger, { LogLevel, StreamHandler } from '@matrixai/logger'; import RPCServer from '@/RPC/RPCServer'; import * as rpcErrors from '@/RPC/errors'; @@ -204,17 +207,19 @@ describe(`${RPCServer.name}`, () => { writable: outputStream, }; + let handledConnectionInfo; const duplexHandler: DuplexStreamHandler = async function* (input, _container, connectionInfo_, _ctx) { - expect(connectionInfo_).toBe(connectionInfo); + handledConnectionInfo = connectionInfo_; for await (const val of input) { yield val; } }; rpcServer.registerDuplexStreamHandler(methodName, duplexHandler); - rpcServer.handleStream(readWriteStream, {} as ConnectionInfo); + rpcServer.handleStream(readWriteStream, connectionInfo); await outputResult; await rpcServer.destroy(); + expect(handledConnectionInfo).toBe(connectionInfo); }, ); @@ -368,6 +373,169 @@ describe(`${RPCServer.name}`, () => { await expect(errorProm).toResolve(); }, ); + testProp('forward middlewares', [specificMessageArb], async (messages) => { + const stream = rpcTestUtils.jsonRpcStream(messages); + const container = {}; + const rpcServer = await RPCServer.createRPCServer({ container, logger }); + const [outputResult, outputStream] = rpcTestUtils.streamToArray(); + const readWriteStream: ReadableWritablePair = { + readable: stream, + writable: outputStream, + }; + + const duplexHandler: DuplexStreamHandler = + async function* (input, _container, _connectionInfo, _ctx) { + for await (const val of input) { + yield val; + } + }; + + rpcServer.registerDuplexStreamHandler(methodName, duplexHandler); + rpcServer.registerForwardMiddleware(() => { + return (input) => + input.pipeThrough( + new TransformStream< + JsonRpcRequest, + JsonRpcRequest + >({ + transform: (chunk, controller) => { + chunk.params = 1; + controller.enqueue(chunk); + }, + }), + ); + }); + rpcServer.handleStream(readWriteStream, {} as ConnectionInfo); + const out = await outputResult; + expect(out.map((v) => v!.toString())).toStrictEqual( + messages.map(() => { + return JSON.stringify({ + jsonrpc: '2.0', + result: 1, + id: null, + }); + }), + ); + await rpcServer.destroy(); + }); + testProp('reverse middlewares', [specificMessageArb], async (messages) => { + const stream = rpcTestUtils.jsonRpcStream(messages); + const container = {}; + const rpcServer = await RPCServer.createRPCServer({ container, logger }); + const [outputResult, outputStream] = rpcTestUtils.streamToArray(); + const readWriteStream: ReadableWritablePair = { + readable: stream, + writable: outputStream, + }; + + const duplexHandler: DuplexStreamHandler = + async function* (input, _container, _connectionInfo, _ctx) { + for await (const val of input) { + yield val; + } + }; + + rpcServer.registerDuplexStreamHandler(methodName, duplexHandler); + rpcServer.registerReverseMiddleware(() => { + return (input) => + input.pipeThrough( + new TransformStream< + JsonRpcResponse, + JsonRpcResponse + >({ + transform: (chunk, controller) => { + if ('result' in chunk) { + chunk.result = 1; + } + controller.enqueue(chunk); + }, + }), + ); + }); + rpcServer.handleStream(readWriteStream, {} as ConnectionInfo); + const out = await outputResult; + expect(out.map((v) => v!.toString())).toStrictEqual( + messages.map(() => { + return JSON.stringify({ + jsonrpc: '2.0', + result: 1, + id: null, + }); + }), + ); + await rpcServer.destroy(); + }); + const validToken = 'VALIDTOKEN'; + const invalidTokenMessageArb = rpcTestUtils.jsonRpcRequestMessageArb( + undefined, + fc.record({ + metadata: fc.record({ + token: fc.string().filter((v) => v !== validToken), + }), + data: rpcTestUtils.safeJsonValueArb, + }), + ); + + testProp( + 'forward middleware authentication', + [invalidTokenMessageArb], + async (message) => { + const stream = rpcTestUtils.jsonRpcStream([message]); + const container = {}; + const rpcServer = await RPCServer.createRPCServer({ container, logger }); + const [outputResult, outputStream] = rpcTestUtils.streamToArray(); + const readWriteStream: ReadableWritablePair = { + readable: stream, + writable: outputStream, + }; + + const duplexHandler: DuplexStreamHandler = + async function* (input, _container, _connectionInfo, _ctx) { + for await (const val of input) { + yield val; + } + }; + + rpcServer.registerDuplexStreamHandler(methodName, duplexHandler); + + type TestType = { + metadata: { + token: string; + }; + data: JSONValue; + }; + rpcServer.registerForwardMiddleware(() => { + let first = true; + return (input, short) => + input.pipeThrough( + new TransformStream< + JsonRpcRequest, + JsonRpcRequest + >({ + transform: (chunk, controller) => { + if (first && chunk.params?.metadata.token !== validToken) { + short({ + jsonrpc: '2.0', + id: null, + error: { + code: 1, + message: 'failure of somekind', + }, + }); + controller.error(new rpcErrors.ErrorRpcNoMessageError()); + } + first = false; + controller.enqueue(chunk); + }, + }), + ); + }); + rpcServer.handleStream(readWriteStream, {} as ConnectionInfo); + await outputResult; + await rpcServer.destroy(); + }, + { numRuns: 1 }, + ); // TODO: // - Test odd conditions for handlers, like extra messages where 1 is expected. diff --git a/tests/RPC/utils.test.ts b/tests/RPC/utils.test.ts index bd737b505..4f8f12206 100644 --- a/tests/RPC/utils.test.ts +++ b/tests/RPC/utils.test.ts @@ -105,6 +105,37 @@ describe('utils tests', () => { { numRuns: 1000 }, ); + testProp( + 'merging transformation stream', + [fc.array(fc.integer()), fc.array(fc.integer())], + async (set1, set2) => { + const [outputResult, outputWriterStream] = + rpcTestUtils.streamToArray(); + const { controllerP, controllerTransformStream } = + rpcUtils.controllerTransformationFactory(); + void controllerTransformStream.readable + .pipeTo(outputWriterStream) + .catch(() => {}); + const writer = controllerTransformStream.writable.getWriter(); + const controller = await controllerP; + const expectedResult: Array = []; + for (let i = 0; i < Math.max(set1.length, set2.length); i++) { + if (set1[i] != null) { + await writer.write(set1[i]); + expectedResult.push(set1[i]); + } + if (set2[i] != null) { + controller.enqueue(set2[i]); + expectedResult.push(set2[i]); + } + } + await writer.close(); + + expect(await outputResult).toStrictEqual(expectedResult); + }, + { numRuns: 1000 }, + ); + // TODO: // - Test for badly structured data }); From 586c990b11af71370d2d2a10e391203e6b3bf0f2 Mon Sep 17 00:00:00 2001 From: Brian Botha Date: Fri, 27 Jan 2023 17:36:53 +1100 Subject: [PATCH 23/44] feat: client handler usage examples Related #500 Related #501 [ci skip] --- src/RPC/RPCServer.ts | 6 +- src/clientRPC/handlers/agentStatus.ts | 49 +++++++++ src/clientRPC/handlers/agentUnlock.ts | 24 ++++ src/clientRPC/utils.ts | 58 ++++++++++ tests/clientRPC/handlers/agentStatus.test.ts | 97 ++++++++++++++++ tests/clientRPC/handlers/agentUnlock.test.ts | 110 +++++++++++++++++++ 6 files changed, 339 insertions(+), 5 deletions(-) create mode 100644 src/clientRPC/handlers/agentStatus.ts create mode 100644 src/clientRPC/handlers/agentUnlock.ts create mode 100644 src/clientRPC/utils.ts create mode 100644 tests/clientRPC/handlers/agentStatus.test.ts create mode 100644 tests/clientRPC/handlers/agentUnlock.test.ts diff --git a/src/RPC/RPCServer.ts b/src/RPC/RPCServer.ts index 7056a823a..000dba7a1 100644 --- a/src/RPC/RPCServer.ts +++ b/src/RPC/RPCServer.ts @@ -13,11 +13,7 @@ import type { ReadableWritablePair } from 'stream/web'; import type { JSONValue, POJO } from '../types'; import type { ConnectionInfo } from '../network/types'; import type { RPCErrorEvent } from './utils'; -import type { - MiddlewareFactory, - MiddlewareShort, - Middleware, -} from 'tokens/types'; +import type { MiddlewareFactory, MiddlewareShort, Middleware } from './types'; import { ReadableStream } from 'stream/web'; import { CreateDestroy, ready } from '@matrixai/async-init/dist/CreateDestroy'; import Logger from '@matrixai/logger'; diff --git a/src/clientRPC/handlers/agentStatus.ts b/src/clientRPC/handlers/agentStatus.ts new file mode 100644 index 000000000..9e4855f0e --- /dev/null +++ b/src/clientRPC/handlers/agentStatus.ts @@ -0,0 +1,49 @@ +import type { UnaryHandler } from '../../RPC/types'; +import type KeyRing from '../../keys/KeyRing'; +import type CertManager from '../../keys/CertManager'; +import type Logger from '@matrixai/logger'; +import type { NodeIdEncoded } from '../../ids'; +import type RPCClient from '../../RPC/RPCClient'; +import type { POJO } from '../../types'; +import * as nodesUtils from '../../nodes/utils'; +import * as keysUtils from '../../keys/utils'; + +type StatusResult = { + pid: number; + nodeId: NodeIdEncoded; + publicJwk: string; +}; +const agentStatusName = 'agentStatus'; +const agentStatusHandler: UnaryHandler = async ( + input, + container: { + keyRing: KeyRing; + certManager: CertManager; + logger: Logger; + }, + _connectionInfo, + _ctx, +) => { + return { + pid: process.pid, + nodeId: nodesUtils.encodeNodeId(container.keyRing.getNodeId()), + publicJwk: JSON.stringify( + keysUtils.publicKeyToJWK(container.keyRing.keyPair.publicKey), + ), + }; +}; + +const agentStatusCaller = async (metadata: POJO, rpcClient: RPCClient) => { + const result = await rpcClient.unaryCaller( + agentStatusName, + null, + metadata, + ); + return { + pid: result.pid, + nodeId: nodesUtils.decodeNodeId(result.nodeId), + publicJwk: result.publicJwk, + }; +}; + +export { agentStatusName, agentStatusHandler, agentStatusCaller }; diff --git a/src/clientRPC/handlers/agentUnlock.ts b/src/clientRPC/handlers/agentUnlock.ts new file mode 100644 index 000000000..3eadc2a8e --- /dev/null +++ b/src/clientRPC/handlers/agentUnlock.ts @@ -0,0 +1,24 @@ +import type { UnaryHandler } from '../../RPC/types'; +import type Logger from '@matrixai/logger'; +import type RPCClient from '../../RPC/RPCClient'; +import type { POJO } from '../../types'; + +const agentUnlockName = 'agentStatus'; +const agentUnlockHandler: UnaryHandler = async ( + _input, + _container: { + logger: Logger; + }, + _connectionInfo, + _ctx, +) => { + // This is a NOP handler, + // authentication and unlocking is handled via middleware + return null; +}; + +const agentUnlockCaller = async (metadata: POJO, rpcClient: RPCClient) => { + await rpcClient.unaryCaller(agentUnlockName, null, metadata); +}; + +export { agentUnlockName, agentUnlockHandler, agentUnlockCaller }; diff --git a/src/clientRPC/utils.ts b/src/clientRPC/utils.ts new file mode 100644 index 000000000..ca96647e7 --- /dev/null +++ b/src/clientRPC/utils.ts @@ -0,0 +1,58 @@ +import type { SessionToken } from '../sessions/types'; +import type KeyRing from '../keys/KeyRing'; +import type SessionManager from '../sessions/SessionManager'; +import type { Authenticate } from '../client/types'; +import * as grpc from '@grpc/grpc-js'; +import * as clientErrors from '../client/errors'; + +/** + * Encodes an Authorization header from session token + * Assumes token is already encoded + * Will mutate metadata if it is passed in + */ +function encodeAuthFromSession( + token: SessionToken, + metadata: grpc.Metadata = new grpc.Metadata(), +): grpc.Metadata { + metadata.set('Authorization', `Bearer ${token}`); + return metadata; +} + +function authenticator( + sessionManager: SessionManager, + keyRing: KeyRing, +): Authenticate { + return async ( + forwardMetadata: grpc.Metadata, + reverseMetadata: grpc.Metadata = new grpc.Metadata(), + ) => { + const auth = forwardMetadata.get('Authorization')[0] as string | undefined; + if (auth == null) { + throw new clientErrors.ErrorClientAuthMissing(); + } + if (auth.startsWith('Bearer ')) { + const token = auth.substring(7) as SessionToken; + if (!(await sessionManager.verifyToken(token))) { + throw new clientErrors.ErrorClientAuthDenied(); + } + } else if (auth.startsWith('Basic ')) { + const encoded = auth.substring(6); + const decoded = Buffer.from(encoded, 'base64').toString('utf-8'); + const match = decoded.match(/:(.*)/); + if (match == null) { + throw new clientErrors.ErrorClientAuthFormat(); + } + const password = match[1]; + if (!(await keyRing.checkPassword(password))) { + throw new clientErrors.ErrorClientAuthDenied(); + } + } else { + throw new clientErrors.ErrorClientAuthMissing(); + } + const token = await sessionManager.createToken(); + encodeAuthFromSession(token, reverseMetadata); + return reverseMetadata; + }; +} + +export { authenticator }; diff --git a/tests/clientRPC/handlers/agentStatus.test.ts b/tests/clientRPC/handlers/agentStatus.test.ts new file mode 100644 index 000000000..b64d33bf1 --- /dev/null +++ b/tests/clientRPC/handlers/agentStatus.test.ts @@ -0,0 +1,97 @@ +import type { ConnectionInfo } from '@/network/types'; +import fs from 'fs'; +import path from 'path'; +import os from 'os'; +import Logger, { LogLevel, StreamHandler } from '@matrixai/logger'; +import { DB } from '@matrixai/db'; +import KeyRing from '@/keys/KeyRing'; +import * as keysUtils from '@/keys/utils'; +import RPCServer from '@/RPC/RPCServer'; +import TaskManager from '@/tasks/TaskManager'; +import CertManager from '@/keys/CertManager'; +import { + agentStatusName, + agentStatusHandler, + agentStatusCaller, +} from '@/clientRPC/handlers/agentStatus'; +import RPCClient from '@/RPC/RPCClient'; +import * as rpcTestUtils from '../../RPC/utils'; + +describe('agentStatus', () => { + const logger = new Logger('agentStatus test', LogLevel.WARN, [ + new StreamHandler(), + ]); + const password = 'helloworld'; + let dataDir: string; + let db: DB; + let keyRing: KeyRing; + let taskManager: TaskManager; + let certManager: CertManager; + + beforeEach(async () => { + dataDir = await fs.promises.mkdtemp( + path.join(os.tmpdir(), 'polykey-test-'), + ); + const keysPath = path.join(dataDir, 'keys'); + const dbPath = path.join(dataDir, 'db'); + db = await DB.createDB({ + dbPath, + logger, + }); + keyRing = await KeyRing.createKeyRing({ + password, + keysPath, + logger, + passwordOpsLimit: keysUtils.passwordOpsLimits.min, + passwordMemLimit: keysUtils.passwordMemLimits.min, + strictMemoryLock: false, + }); + taskManager = await TaskManager.createTaskManager({ db, logger }); + certManager = await CertManager.createCertManager({ + db, + keyRing, + taskManager, + logger, + }); + }); + afterEach(async () => { + await certManager.stop(); + await taskManager.stop(); + await keyRing.stop(); + await db.stop(); + await fs.promises.rm(dataDir, { + force: true, + recursive: true, + }); + }); + test('get status', async () => { + // Setup + const rpcServer = await RPCServer.createRPCServer({ + container: { + // KeyRing, + // certManager, + logger, + }, + logger, + }); + rpcServer.registerUnaryHandler(agentStatusName, agentStatusHandler); + const rpcClient = await RPCClient.createRPCClient({ + streamPairCreateCallback: async () => { + const { clientPair, serverPair } = rpcTestUtils.createTapPairs(); + rpcServer.handleStream(serverPair, {} as ConnectionInfo); + return clientPair; + }, + logger, + }); + + // Doing the test + const result = await agentStatusCaller({}, rpcClient); + expect(result).toStrictEqual({ + pid: process.pid, + nodeId: keyRing.getNodeId(), + publicJwk: JSON.stringify( + keysUtils.publicKeyToJWK(keyRing.keyPair.publicKey), + ), + }); + }); +}); diff --git a/tests/clientRPC/handlers/agentUnlock.test.ts b/tests/clientRPC/handlers/agentUnlock.test.ts new file mode 100644 index 000000000..b6194d864 --- /dev/null +++ b/tests/clientRPC/handlers/agentUnlock.test.ts @@ -0,0 +1,110 @@ +import type { ConnectionInfo } from '@/network/types'; +import fs from 'fs'; +import path from 'path'; +import os from 'os'; +import { TransformStream } from 'stream/web'; +import Logger, { LogLevel, StreamHandler } from '@matrixai/logger'; +import { DB } from '@matrixai/db'; +import KeyRing from '@/keys/KeyRing'; +import * as keysUtils from '@/keys/utils'; +import RPCServer from '@/RPC/RPCServer'; +import TaskManager from '@/tasks/TaskManager'; +import CertManager from '@/keys/CertManager'; +import { + agentUnlockName, + agentUnlockHandler, + agentUnlockCaller, +} from '@/clientRPC/handlers/agentUnlock'; +import RPCClient from '@/RPC/RPCClient'; +import * as rpcTestUtils from '../../RPC/utils'; + +describe('agentStatus', () => { + const logger = new Logger('agentStatus test', LogLevel.WARN, [ + new StreamHandler(), + ]); + const password = 'helloworld'; + let dataDir: string; + let db: DB; + let keyRing: KeyRing; + let taskManager: TaskManager; + let certManager: CertManager; + + beforeEach(async () => { + dataDir = await fs.promises.mkdtemp( + path.join(os.tmpdir(), 'polykey-test-'), + ); + const keysPath = path.join(dataDir, 'keys'); + const dbPath = path.join(dataDir, 'db'); + db = await DB.createDB({ + dbPath, + logger, + }); + keyRing = await KeyRing.createKeyRing({ + password, + keysPath, + logger, + passwordOpsLimit: keysUtils.passwordOpsLimits.min, + passwordMemLimit: keysUtils.passwordMemLimits.min, + strictMemoryLock: false, + }); + taskManager = await TaskManager.createTaskManager({ db, logger }); + certManager = await CertManager.createCertManager({ + db, + keyRing, + taskManager, + logger, + }); + }); + afterEach(async () => { + await certManager.stop(); + await taskManager.stop(); + await keyRing.stop(); + await db.stop(); + await fs.promises.rm(dataDir, { + force: true, + recursive: true, + }); + }); + test('get status', async () => { + // Setup + const rpcServer = await RPCServer.createRPCServer({ + container: { + // KeyRing, + // certManager, + logger, + }, + logger, + }); + rpcServer.registerUnaryHandler(agentUnlockName, agentUnlockHandler); + rpcServer.registerForwardMiddleware(() => { + return (input) => { + // This middleware needs to check the first message for the token + return input.pipeThrough( + new TransformStream({ + transform: (chunk, controller) => { + controller.enqueue(chunk); + }, + }), + ); + }; + }); + const rpcClient = await RPCClient.createRPCClient({ + streamPairCreateCallback: async () => { + const { clientPair, serverPair } = rpcTestUtils.createTapPairs(); + rpcServer.handleStream(serverPair, {} as ConnectionInfo); + return clientPair; + }, + logger, + }); + + // Doing the test + const result = await agentUnlockCaller({}, rpcClient); + expect(result).toStrictEqual({ + pid: process.pid, + nodeId: keyRing.getNodeId(), + publicJwk: JSON.stringify( + keysUtils.publicKeyToJWK(keyRing.keyPair.publicKey), + ), + }); + }); +}); From fb198ff4fe1f0821421122a9b22dde9af0cc32c5 Mon Sep 17 00:00:00 2001 From: Brian Botha Date: Fri, 27 Jan 2023 19:24:42 +1100 Subject: [PATCH 24/44] fix: refactoring middleware Related #500 Related #502 [ci skip] --- src/RPC/RPCClient.ts | 74 ++++++++++------------- src/RPC/RPCServer.ts | 89 +++++++++++---------------- src/RPC/types.ts | 13 ++-- tests/RPC/RPCClient.test.ts | 62 +++++++++---------- tests/RPC/RPCServer.test.ts | 116 +++++++++++++++++++----------------- 5 files changed, 164 insertions(+), 190 deletions(-) diff --git a/src/RPC/RPCClient.ts b/src/RPC/RPCClient.ts index 41109f76a..8c348ac1a 100644 --- a/src/RPC/RPCClient.ts +++ b/src/RPC/RPCClient.ts @@ -5,7 +5,6 @@ import type { JsonRpcRequest, JsonRpcResponse, MiddlewareFactory, - Middleware, } from './types'; import { CreateDestroy, ready } from '@matrixai/async-init/dist/CreateDestroy'; import Logger from '@matrixai/logger'; @@ -55,34 +54,37 @@ class RPCClient { method: string, _metadata: POJO, ): Promise> { - const streamPair = await this.streamPairCreateCallback(); - let reverseMiddlewareStream = streamPair.readable.pipeThrough( - new rpcUtils.JsonToJsonMessageStream(rpcUtils.parseJsonRpcResponse), - ); - for (const middleWare of this.reverseMiddleware) { - const middle = middleWare(); - reverseMiddlewareStream = middle(reverseMiddlewareStream); - } - const outputStream = reverseMiddlewareStream.pipeThrough( - new rpcUtils.ClientOutputTransformerStream(), - ); - const inputMessageTransformer = + // Creating caller side transforms + const outputMessageTransforStream = + new rpcUtils.ClientOutputTransformerStream(); + const inputMessageTransformStream = new rpcUtils.ClientInputTransformerStream(method); - let forwardMiddlewareStream = inputMessageTransformer.readable; - for (const middleware of this.forwardMiddleWare) { - const middle = middleware(); - forwardMiddlewareStream = middle(forwardMiddlewareStream); + let reverseStream = outputMessageTransforStream.writable; + let forwardStream = inputMessageTransformStream.readable; + // Setting up middleware chains + for (const middlewareFactory of this.middleware) { + const middleware = middlewareFactory(); + forwardStream = forwardStream.pipeThrough(middleware.forward); + void middleware.reverse.readable.pipeTo(reverseStream).catch(() => {}); + reverseStream = middleware.reverse.writable; } - void forwardMiddlewareStream + // Hooking up agnostic stream side + const streamPair = await this.streamPairCreateCallback(); + void streamPair.readable + .pipeThrough( + new rpcUtils.JsonToJsonMessageStream(rpcUtils.parseJsonRpcResponse), + ) + .pipeTo(reverseStream) + .catch(() => {}); + void forwardStream .pipeThrough(new rpcUtils.JsonMessageToJsonStream()) .pipeTo(streamPair.writable) .catch(() => {}); - const inputStream = inputMessageTransformer.writable; // Returning interface return { - readable: outputStream, - writable: inputStream, + readable: outputMessageTransforStream.readable, + writable: inputMessageTransformStream.writable, }; } @@ -205,37 +207,23 @@ class RPCClient { return callerInterface.output; } - protected forwardMiddleWare: Array< - MiddlewareFactory>> - > = []; - protected reverseMiddleware: Array< - MiddlewareFactory>> + protected middleware: Array< + MiddlewareFactory, JsonRpcResponse> > = []; @ready(new rpcErrors.ErrorRpcDestroyed()) - public registerForwardMiddleware( - middlewareFactory: MiddlewareFactory>>, - ) { - this.forwardMiddleWare.push(middlewareFactory); - } - - @ready(new rpcErrors.ErrorRpcDestroyed()) - public clearForwardMiddleware() { - this.reverseMiddleware = []; - } - - @ready(new rpcErrors.ErrorRpcDestroyed()) - public registerReverseMiddleware( + public registerMiddleware( middlewareFactory: MiddlewareFactory< - Middleware> + JsonRpcRequest, + JsonRpcResponse >, ) { - this.reverseMiddleware.push(middlewareFactory); + this.middleware.push(middlewareFactory); } @ready(new rpcErrors.ErrorRpcDestroyed()) - public clearReverseMiddleware() { - this.reverseMiddleware = []; + public clearMiddleware() { + this.middleware = []; } } diff --git a/src/RPC/RPCServer.ts b/src/RPC/RPCServer.ts index 000dba7a1..7d360d280 100644 --- a/src/RPC/RPCServer.ts +++ b/src/RPC/RPCServer.ts @@ -13,7 +13,7 @@ import type { ReadableWritablePair } from 'stream/web'; import type { JSONValue, POJO } from '../types'; import type { ConnectionInfo } from '../network/types'; import type { RPCErrorEvent } from './utils'; -import type { MiddlewareFactory, MiddlewareShort, Middleware } from './types'; +import type { MiddlewareFactory } from './types'; import { ReadableStream } from 'stream/web'; import { CreateDestroy, ready } from '@matrixai/async-init/dist/CreateDestroy'; import Logger from '@matrixai/logger'; @@ -154,22 +154,25 @@ class RPCServer { void handlerProm .finally(() => this.activeStreams.delete(handlerProm)) .catch(() => {}); - // Setting up forward middleware - let middlewareStream = streamPair.readable.pipeThrough( + // Setting up middleware + let forwardStream = streamPair.readable.pipeThrough( new rpcUtils.JsonToJsonMessageStream(rpcUtils.parseJsonRpcRequest), ); - const shortMessageQueue: Array = []; - for (const forwardMiddleWareFactory of this.forwardMiddleWare) { - const middleware = forwardMiddleWareFactory(); - middlewareStream = middleware( - middlewareStream, - (value: JsonRpcResponse) => shortMessageQueue.push(value), - ); + const outputTransformStream = new rpcUtils.JsonMessageToJsonStream(); + void outputTransformStream.readable + .pipeTo(streamPair.writable) + .catch(() => {}); + let reverseStream = outputTransformStream.writable; + for (const middlewareFactory of this.middleware) { + const middleware = middlewareFactory(); + forwardStream = forwardStream.pipeThrough(middleware.forward); + void middleware.reverse.readable.pipeTo(reverseStream).catch(() => {}); + reverseStream = middleware.reverse.writable; } // While ReadableStream can be converted to AsyncIterable, we want it as // a generator. const inputGen = async function* () { - for await (const dataMessage of middlewareStream) { + for await (const dataMessage of forwardStream) { yield dataMessage; } }; @@ -186,9 +189,8 @@ class RPCServer { const input = inputGen(); if (ctx.signal.aborted) throw ctx.signal.reason; const leadingMetadataMessage = await input.next(); - if (leadingMetadataMessage.done === true) { - throw new rpcErrors.ErrorRpcProtocal('Stream ended before response'); - } + // If the stream ends early then we just stop processing + if (leadingMetadataMessage.done === true) return; const method = leadingMetadataMessage.value.method; const initialParams = leadingMetadataMessage.value.params; const dataGen = async function* () { @@ -223,14 +225,18 @@ class RPCServer { const outputGenerator = outputGen(); - let reverseMiddlewareStream = new ReadableStream< + const reverseMiddlewareStream = new ReadableStream< JsonRpcResponse >({ pull: async (controller) => { try { const { value, done } = await outputGenerator.next(); if (done) { - controller.close(); + try { + controller.close(); + } catch { + // Ignore already closed error + } resolve(); return; } @@ -259,7 +265,11 @@ class RPCServer { }), ); } - controller.close(); + try { + controller.close(); + } catch { + // Ignore already closed error + } resolve(); } }, @@ -267,16 +277,7 @@ class RPCServer { await outputGenerator.throw(reason); }, }); - // Setting up reverse middleware - for (const reverseMiddleWareFactory of this.reverseMiddleware) { - const middleware = reverseMiddleWareFactory(); - reverseMiddlewareStream = middleware(reverseMiddlewareStream); - } - reverseMiddlewareStream - .pipeThrough(new rpcUtils.QueueMergingTransformStream(shortMessageQueue)) - .pipeThrough(new rpcUtils.JsonMessageToJsonStream()) - .pipeTo(streamPair.writable) - .catch(() => {}); + void reverseMiddlewareStream.pipeTo(reverseStream).catch(() => {}); } @ready(new rpcErrors.ErrorRpcDestroyed()) @@ -297,41 +298,23 @@ class RPCServer { this.events.removeEventListener(type, callback, options); } - protected forwardMiddleWare: Array< - MiddlewareFactory< - MiddlewareShort, JsonRpcResponse> - > - > = []; - protected reverseMiddleware: Array< - MiddlewareFactory>> + protected middleware: Array< + MiddlewareFactory, JsonRpcResponse> > = []; @ready(new rpcErrors.ErrorRpcDestroyed()) - public registerForwardMiddleware( + public registerMiddleware( middlewareFactory: MiddlewareFactory< - MiddlewareShort, JsonRpcResponse> - >, - ) { - this.forwardMiddleWare.push(middlewareFactory); - } - - @ready(new rpcErrors.ErrorRpcDestroyed()) - public clearForwardMiddleware() { - this.reverseMiddleware = []; - } - - @ready(new rpcErrors.ErrorRpcDestroyed()) - public registerReverseMiddleware( - middlewareFactory: MiddlewareFactory< - Middleware> + JsonRpcRequest, + JsonRpcResponse >, ) { - this.reverseMiddleware.push(middlewareFactory); + this.middleware.push(middlewareFactory); } @ready(new rpcErrors.ErrorRpcDestroyed()) - public clearReverseMiddleware() { - this.reverseMiddleware = []; + public clearMiddleware() { + this.middleware = []; } } diff --git a/src/RPC/types.ts b/src/RPC/types.ts index a233da2df..b1554469d 100644 --- a/src/RPC/types.ts +++ b/src/RPC/types.ts @@ -2,7 +2,6 @@ import type { JSONValue, POJO } from '../types'; import type { ConnectionInfo } from '../network/types'; import type { ContextCancellable } from '../contexts/types'; import type { ReadableWritablePair } from 'stream/web'; -import type { ReadableStream } from 'stream/web'; /** * This is the JSON RPC request object. this is the generic message type used for the RPC. @@ -128,12 +127,10 @@ type StreamPairCreateCallback = () => Promise< ReadableWritablePair >; -type MiddlewareShort = ( - input: ReadableStream, - short: (value: K) => void, -) => ReadableStream; -type Middleware = (input: ReadableStream) => ReadableStream; -type MiddlewareFactory = () => T; +type MiddlewareFactory = () => { + forward: ReadableWritablePair; + reverse: ReadableWritablePair; +}; export type { JsonRpcRequestMessage, @@ -149,7 +146,5 @@ export type { ClientStreamHandler, UnaryHandler, StreamPairCreateCallback, - MiddlewareShort, - Middleware, MiddlewareFactory, }; diff --git a/tests/RPC/RPCClient.test.ts b/tests/RPC/RPCClient.test.ts index 9ce0d289e..accc20ed5 100644 --- a/tests/RPC/RPCClient.test.ts +++ b/tests/RPC/RPCClient.test.ts @@ -341,21 +341,21 @@ describe(`${RPCClient.name}`, () => { logger, }); - rpcClient.registerForwardMiddleware(() => { - return (input) => - input.pipeThrough( - new TransformStream< - JsonRpcRequest, - JsonRpcRequest - >({ - transform: (chunk, controller) => { - controller.enqueue({ - ...chunk, - params: 'one', - }); - }, - }), - ); + rpcClient.registerMiddleware(() => { + return { + forward: new TransformStream< + JsonRpcRequest, + JsonRpcRequest + >({ + transform: (chunk, controller) => { + controller.enqueue({ + ...chunk, + params: 'one', + }); + }, + }), + reverse: new TransformStream(), + }; }); const callerInterface = await rpcClient.duplexStreamCaller< JSONValue, @@ -391,7 +391,7 @@ describe(`${RPCClient.name}`, () => { await rpcClient.destroy(); }, ); - testProp.only( + testProp( 'generic duplex caller with reverse Middleware', [specificMessageArb], async (messages) => { @@ -407,21 +407,21 @@ describe(`${RPCClient.name}`, () => { logger, }); - rpcClient.registerReverseMiddleware(() => { - return (input) => - input.pipeThrough( - new TransformStream< - JsonRpcResponse, - JsonRpcResponse - >({ - transform: (chunk, controller) => { - controller.enqueue({ - ...chunk, - result: 'one', - }); - }, - }), - ); + rpcClient.registerMiddleware(() => { + return { + forward: new TransformStream(), + reverse: new TransformStream< + JsonRpcResponse, + JsonRpcResponse + >({ + transform: (chunk, controller) => { + controller.enqueue({ + ...chunk, + result: 'one', + }); + }, + }), + }; }); const callerInterface = await rpcClient.duplexStreamCaller< JSONValue, diff --git a/tests/RPC/RPCServer.test.ts b/tests/RPC/RPCServer.test.ts index 40b9881d8..0af80fb19 100644 --- a/tests/RPC/RPCServer.test.ts +++ b/tests/RPC/RPCServer.test.ts @@ -4,6 +4,7 @@ import type { JsonRpcMessage, JsonRpcRequest, JsonRpcResponse, + JsonRpcResponseError, ServerStreamHandler, UnaryHandler, } from '@/RPC/types'; @@ -391,19 +392,16 @@ describe(`${RPCServer.name}`, () => { }; rpcServer.registerDuplexStreamHandler(methodName, duplexHandler); - rpcServer.registerForwardMiddleware(() => { - return (input) => - input.pipeThrough( - new TransformStream< - JsonRpcRequest, - JsonRpcRequest - >({ - transform: (chunk, controller) => { - chunk.params = 1; - controller.enqueue(chunk); - }, - }), - ); + rpcServer.registerMiddleware(() => { + return { + forward: new TransformStream({ + transform: (chunk, controller) => { + chunk.params = 1; + controller.enqueue(chunk); + }, + }), + reverse: new TransformStream(), + }; }); rpcServer.handleStream(readWriteStream, {} as ConnectionInfo); const out = await outputResult; @@ -436,21 +434,16 @@ describe(`${RPCServer.name}`, () => { }; rpcServer.registerDuplexStreamHandler(methodName, duplexHandler); - rpcServer.registerReverseMiddleware(() => { - return (input) => - input.pipeThrough( - new TransformStream< - JsonRpcResponse, - JsonRpcResponse - >({ - transform: (chunk, controller) => { - if ('result' in chunk) { - chunk.result = 1; - } - controller.enqueue(chunk); - }, - }), - ); + rpcServer.registerMiddleware(() => { + return { + forward: new TransformStream(), + reverse: new TransformStream({ + transform: (chunk, controller) => { + if ('result' in chunk) chunk.result = 1; + controller.enqueue(chunk); + }, + }), + }; }); rpcServer.handleStream(readWriteStream, {} as ConnectionInfo); const out = await outputResult; @@ -504,37 +497,52 @@ describe(`${RPCServer.name}`, () => { }; data: JSONValue; }; - rpcServer.registerForwardMiddleware(() => { + const failureMessage: JsonRpcResponseError = { + jsonrpc: '2.0', + id: null, + error: { + code: 1, + message: 'failure of somekind', + }, + }; + rpcServer.registerMiddleware(() => { let first = true; - return (input, short) => - input.pipeThrough( - new TransformStream< - JsonRpcRequest, - JsonRpcRequest - >({ - transform: (chunk, controller) => { - if (first && chunk.params?.metadata.token !== validToken) { - short({ - jsonrpc: '2.0', - id: null, - error: { - code: 1, - message: 'failure of somekind', - }, - }); - controller.error(new rpcErrors.ErrorRpcNoMessageError()); - } - first = false; - controller.enqueue(chunk); - }, - }), - ); + let reverseController: TransformStreamDefaultController< + JsonRpcResponse + >; + return { + forward: new TransformStream< + JsonRpcRequest, + JsonRpcRequest + >({ + transform: (chunk, controller) => { + if (first && chunk.params?.metadata.token !== validToken) { + reverseController.enqueue(failureMessage); + // Closing streams early + controller.terminate(); + reverseController.terminate(); + } + first = false; + controller.enqueue(chunk); + }, + }), + reverse: new TransformStream({ + start: (controller) => { + // Kidnapping reverse controller + reverseController = controller; + }, + transform: (chunk, controller) => { + controller.enqueue(chunk); + }, + }), + }; }); rpcServer.handleStream(readWriteStream, {} as ConnectionInfo); - await outputResult; + expect((await outputResult).toString()).toEqual( + JSON.stringify(failureMessage), + ); await rpcServer.destroy(); }, - { numRuns: 1 }, ); // TODO: From 29e61a95926967ff7653e52b32199195afa551c1 Mon Sep 17 00:00:00 2001 From: Brian Botha Date: Mon, 30 Jan 2023 19:09:43 +1100 Subject: [PATCH 25/44] feat: agentUnlock example Related #500 Related #501 Related #502 [ci skip] --- src/RPC/RPCClient.ts | 35 +--- src/RPC/RPCServer.ts | 5 +- src/clientRPC/handlers/agentUnlock.ts | 29 ++- src/clientRPC/types.ts | 10 + src/clientRPC/utils.ts | 199 ++++++++++++++----- tests/clientRPC/handlers/agentStatus.test.ts | 4 +- tests/clientRPC/handlers/agentUnlock.test.ts | 60 ++++-- 7 files changed, 235 insertions(+), 107 deletions(-) create mode 100644 src/clientRPC/types.ts diff --git a/src/RPC/RPCClient.ts b/src/RPC/RPCClient.ts index 8c348ac1a..0931a8020 100644 --- a/src/RPC/RPCClient.ts +++ b/src/RPC/RPCClient.ts @@ -1,5 +1,5 @@ import type { StreamPairCreateCallback } from './types'; -import type { JSONValue, POJO } from 'types'; +import type { JSONValue } from 'types'; import type { ReadableWritablePair } from 'stream/web'; import type { JsonRpcRequest, @@ -52,7 +52,6 @@ class RPCClient { @ready(new rpcErrors.ErrorRpcDestroyed()) public async duplexStreamCaller( method: string, - _metadata: POJO, ): Promise> { // Creating caller side transforms const outputMessageTransforStream = @@ -92,12 +91,8 @@ class RPCClient { public async serverStreamCaller( method: string, parameters: I, - metadata: POJO, ) { - const callerInterface = await this.duplexStreamCaller( - method, - metadata, - ); + const callerInterface = await this.duplexStreamCaller(method); const writer = callerInterface.writable.getWriter(); await writer.write(parameters); await writer.close(); @@ -108,12 +103,8 @@ class RPCClient { @ready(new rpcErrors.ErrorRpcDestroyed()) public async clientStreamCaller( method: string, - metadata: POJO, ) { - const callerInterface = await this.duplexStreamCaller( - method, - metadata, - ); + const callerInterface = await this.duplexStreamCaller(method); const reader = callerInterface.readable.getReader(); const output = reader.read().then(({ value, done }) => { if (done) { @@ -131,12 +122,8 @@ class RPCClient { public async unaryCaller( method: string, parameters: I, - metadata: POJO, ): Promise { - const callerInterface = await this.duplexStreamCaller( - method, - metadata, - ); + const callerInterface = await this.duplexStreamCaller(method); const reader = callerInterface.readable.getReader(); const writer = callerInterface.writable.getWriter(); await writer.write(parameters); @@ -153,12 +140,8 @@ class RPCClient { public async withDuplexCaller( method: string, f: (output: AsyncGenerator) => AsyncGenerator, - metadata: POJO, ): Promise { - const callerInterface = await this.duplexStreamCaller( - method, - metadata, - ); + const callerInterface = await this.duplexStreamCaller(method); const outputGenerator = async function* () { for await (const value of callerInterface.readable) { yield value; @@ -176,12 +159,10 @@ class RPCClient { method: string, parameters: I, f: (output: AsyncGenerator) => Promise, - metadata: POJO, ) { const callerInterface = await this.serverStreamCaller( method, parameters, - metadata, ); const outputGenerator = async function* () { yield* callerInterface; @@ -193,12 +174,8 @@ class RPCClient { public async withClientCaller( method: string, f: () => AsyncGenerator, - metadata: POJO, ): Promise { - const callerInterface = await this.clientStreamCaller( - method, - metadata, - ); + const callerInterface = await this.clientStreamCaller(method); const writer = callerInterface.writable.getWriter(); for await (const value of f()) { await writer.write(value); diff --git a/src/RPC/RPCServer.ts b/src/RPC/RPCServer.ts index 7d360d280..e6166fba8 100644 --- a/src/RPC/RPCServer.ts +++ b/src/RPC/RPCServer.ts @@ -20,6 +20,7 @@ import Logger from '@matrixai/logger'; import { PromiseCancellable } from '@matrixai/async-cancellable'; import * as rpcUtils from './utils'; import * as rpcErrors from './errors'; +import { sysexits } from '../errors'; interface RPCServer extends CreateDestroy {} @CreateDestroy() @@ -245,8 +246,8 @@ class RPCServer { if (rpcUtils.isReturnableError(e)) { // We want to convert this error to an error message and pass it along const rpcError: JsonRpcError = { - code: e.exitCode, - message: e.description, + code: e.exitCode ?? sysexits.UNKNOWN, + message: e.description ?? '', data: rpcUtils.fromError(e), }; const rpcErrorMessage: JsonRpcResponseError = { diff --git a/src/clientRPC/handlers/agentUnlock.ts b/src/clientRPC/handlers/agentUnlock.ts index 3eadc2a8e..e1e77ad30 100644 --- a/src/clientRPC/handlers/agentUnlock.ts +++ b/src/clientRPC/handlers/agentUnlock.ts @@ -1,10 +1,14 @@ import type { UnaryHandler } from '../../RPC/types'; import type Logger from '@matrixai/logger'; import type RPCClient from '../../RPC/RPCClient'; -import type { POJO } from '../../types'; +import type { JSONValue } from '../../types'; +import type { ClientDataAndMetadata } from '../types'; const agentUnlockName = 'agentStatus'; -const agentUnlockHandler: UnaryHandler = async ( +const agentUnlockHandler: UnaryHandler< + ClientDataAndMetadata, + ClientDataAndMetadata +> = async ( _input, _container: { logger: Logger; @@ -13,12 +17,25 @@ const agentUnlockHandler: UnaryHandler = async ( _ctx, ) => { // This is a NOP handler, - // authentication and unlocking is handled via middleware - return null; + // authentication and unlocking is handled via middleware. + // Failure to authenticate will be an error from the middleware layer. + return { + metadata: {}, + data: null, + }; }; -const agentUnlockCaller = async (metadata: POJO, rpcClient: RPCClient) => { - await rpcClient.unaryCaller(agentUnlockName, null, metadata); +const agentUnlockCaller = async ( + metadata: Record, + rpcClient: RPCClient, +) => { + return rpcClient.unaryCaller< + ClientDataAndMetadata, + ClientDataAndMetadata + >(agentUnlockName, { + metadata: metadata, + data: null, + }); }; export { agentUnlockName, agentUnlockHandler, agentUnlockCaller }; diff --git a/src/clientRPC/types.ts b/src/clientRPC/types.ts new file mode 100644 index 000000000..b570749f1 --- /dev/null +++ b/src/clientRPC/types.ts @@ -0,0 +1,10 @@ +import type { JSONValue } from '../types'; + +type ClientDataAndMetadata = { + metadata: JSONValue & { + Authorization?: string; + }; + data: T; +}; + +export type { ClientDataAndMetadata }; diff --git a/src/clientRPC/utils.ts b/src/clientRPC/utils.ts index ca96647e7..0aa45cd16 100644 --- a/src/clientRPC/utils.ts +++ b/src/clientRPC/utils.ts @@ -1,58 +1,165 @@ import type { SessionToken } from '../sessions/types'; import type KeyRing from '../keys/KeyRing'; import type SessionManager from '../sessions/SessionManager'; -import type { Authenticate } from '../client/types'; -import * as grpc from '@grpc/grpc-js'; +import type { Session } from '../sessions'; +import type { ClientDataAndMetadata } from './types'; +import type { JSONValue } from '../types'; +import type { + JsonRpcRequest, + JsonRpcResponse, + MiddlewareFactory, +} from '../RPC/types'; +import { TransformStream } from 'stream/web'; import * as clientErrors from '../client/errors'; +import * as utils from '../utils'; -/** - * Encodes an Authorization header from session token - * Assumes token is already encoded - * Will mutate metadata if it is passed in - */ -function encodeAuthFromSession( - token: SessionToken, - metadata: grpc.Metadata = new grpc.Metadata(), -): grpc.Metadata { - metadata.set('Authorization', `Bearer ${token}`); - return metadata; -} - -function authenticator( +async function authenticate( sessionManager: SessionManager, keyRing: KeyRing, -): Authenticate { - return async ( - forwardMetadata: grpc.Metadata, - reverseMetadata: grpc.Metadata = new grpc.Metadata(), - ) => { - const auth = forwardMetadata.get('Authorization')[0] as string | undefined; - if (auth == null) { - throw new clientErrors.ErrorClientAuthMissing(); + message: JsonRpcRequest>, +) { + if (message.params == null) throw new clientErrors.ErrorClientAuthMissing(); + if (message.params.metadata == null) { + throw new clientErrors.ErrorClientAuthMissing(); + } + const auth = message.params.metadata.Authorization; + if (auth == null) { + throw new clientErrors.ErrorClientAuthMissing(); + } + if (auth.startsWith('Bearer ')) { + const token = auth.substring(7) as SessionToken; + if (!(await sessionManager.verifyToken(token))) { + throw new clientErrors.ErrorClientAuthDenied(); } - if (auth.startsWith('Bearer ')) { - const token = auth.substring(7) as SessionToken; - if (!(await sessionManager.verifyToken(token))) { - throw new clientErrors.ErrorClientAuthDenied(); - } - } else if (auth.startsWith('Basic ')) { - const encoded = auth.substring(6); - const decoded = Buffer.from(encoded, 'base64').toString('utf-8'); - const match = decoded.match(/:(.*)/); - if (match == null) { - throw new clientErrors.ErrorClientAuthFormat(); - } - const password = match[1]; - if (!(await keyRing.checkPassword(password))) { - throw new clientErrors.ErrorClientAuthDenied(); - } - } else { - throw new clientErrors.ErrorClientAuthMissing(); + } else if (auth.startsWith('Basic ')) { + const encoded = auth.substring(6); + const decoded = Buffer.from(encoded, 'base64').toString('utf-8'); + const match = decoded.match(/:(.*)/); + if (match == null) { + throw new clientErrors.ErrorClientAuthFormat(); } - const token = await sessionManager.createToken(); - encodeAuthFromSession(token, reverseMetadata); - return reverseMetadata; + const password = match[1]; + if (!(await keyRing.checkPassword(password))) { + throw new clientErrors.ErrorClientAuthDenied(); + } + } else { + throw new clientErrors.ErrorClientAuthMissing(); + } + const token = await sessionManager.createToken(); + return `Bearer ${token}`; +} + +function decodeAuth(messageParams: ClientDataAndMetadata) { + const auth = messageParams.metadata.Authorization; + if (auth == null || !auth.startsWith('Bearer ')) { + return; + } + return auth.substring(7) as SessionToken; +} + +function encodeAuthFromPassword(password: string): string { + const encoded = Buffer.from(`:${password}`).toString('base64'); + return `Basic ${encoded}`; +} + +function authenticationMiddlewareServer( + sessionManager: SessionManager, + keyRing: KeyRing, +): MiddlewareFactory< + JsonRpcRequest>, + JsonRpcResponse> +> { + return () => { + let forwardFirst = true; + let reverseController; + let outgoingToken: string | null = null; + return { + forward: new TransformStream< + JsonRpcRequest>, + JsonRpcRequest> + >({ + transform: async (chunk, controller) => { + if (forwardFirst) { + try { + outgoingToken = await authenticate( + sessionManager, + keyRing, + chunk, + ); + } catch (e) { + controller.terminate(); + reverseController.terminate(); + return; + } + } + forwardFirst = false; + controller.enqueue(chunk); + }, + }), + reverse: new TransformStream({ + start: (controller) => { + reverseController = controller; + }, + transform: (chunk, controller) => { + // Add the outgoing metadata to the next message. + if (outgoingToken != null && 'result' in chunk) { + chunk.result.metadata.Authorization = outgoingToken; + outgoingToken = null; + } + controller.enqueue(chunk); + }, + }), + }; + }; +} + +function authenticationMiddlewareClient( + session: Session, +): MiddlewareFactory< + JsonRpcRequest>, + JsonRpcResponse> +> { + return () => { + let forwardFirst = true; + return { + forward: new TransformStream< + JsonRpcRequest>, + JsonRpcRequest> + >({ + transform: async (chunk, controller) => { + if (forwardFirst) { + if (chunk.params == null) utils.never(); + if (chunk.params.metadata.Authorization == null) { + const token = await session.readToken(); + if (token != null) { + chunk.params.metadata.Authorization = `Bearer ${token}`; + } + } + } + forwardFirst = false; + controller.enqueue(chunk); + }, + }), + reverse: new TransformStream< + JsonRpcResponse>, + JsonRpcResponse> + >({ + transform: async (chunk, controller) => { + controller.enqueue(chunk); + if (!('result' in chunk)) return; + const token = decodeAuth(chunk.result); + if (token == null) return; + await session.writeToken(token); + }, + }), + }; }; } -export { authenticator }; +export { + authenticate, + decodeAuth, + encodeAuthFromPassword, + authenticationMiddlewareServer, + authenticationMiddlewareClient, +}; diff --git a/tests/clientRPC/handlers/agentStatus.test.ts b/tests/clientRPC/handlers/agentStatus.test.ts index b64d33bf1..6f0de9e39 100644 --- a/tests/clientRPC/handlers/agentStatus.test.ts +++ b/tests/clientRPC/handlers/agentStatus.test.ts @@ -68,8 +68,8 @@ describe('agentStatus', () => { // Setup const rpcServer = await RPCServer.createRPCServer({ container: { - // KeyRing, - // certManager, + keyRing, + certManager, logger, }, logger, diff --git a/tests/clientRPC/handlers/agentUnlock.test.ts b/tests/clientRPC/handlers/agentUnlock.test.ts index b6194d864..9dacadc70 100644 --- a/tests/clientRPC/handlers/agentUnlock.test.ts +++ b/tests/clientRPC/handlers/agentUnlock.test.ts @@ -2,7 +2,6 @@ import type { ConnectionInfo } from '@/network/types'; import fs from 'fs'; import path from 'path'; import os from 'os'; -import { TransformStream } from 'stream/web'; import Logger, { LogLevel, StreamHandler } from '@matrixai/logger'; import { DB } from '@matrixai/db'; import KeyRing from '@/keys/KeyRing'; @@ -16,6 +15,8 @@ import { agentUnlockCaller, } from '@/clientRPC/handlers/agentUnlock'; import RPCClient from '@/RPC/RPCClient'; +import { Session, SessionManager } from '@/sessions'; +import * as abcUtils from '@/clientRPC/utils'; import * as rpcTestUtils from '../../RPC/utils'; describe('agentStatus', () => { @@ -28,6 +29,8 @@ describe('agentStatus', () => { let keyRing: KeyRing; let taskManager: TaskManager; let certManager: CertManager; + let session: Session; + let sessionManager: SessionManager; beforeEach(async () => { dataDir = await fs.promises.mkdtemp( @@ -35,6 +38,7 @@ describe('agentStatus', () => { ); const keysPath = path.join(dataDir, 'keys'); const dbPath = path.join(dataDir, 'db'); + const sessionPath = path.join(dataDir, 'session'); db = await DB.createDB({ dbPath, logger, @@ -54,6 +58,15 @@ describe('agentStatus', () => { taskManager, logger, }); + session = await Session.createSession({ + sessionTokenPath: sessionPath, + logger, + }); + sessionManager = await SessionManager.createSessionManager({ + db, + keyRing, + logger, + }); }); afterEach(async () => { await certManager.stop(); @@ -69,25 +82,14 @@ describe('agentStatus', () => { // Setup const rpcServer = await RPCServer.createRPCServer({ container: { - // KeyRing, - // certManager, logger, }, logger, }); rpcServer.registerUnaryHandler(agentUnlockName, agentUnlockHandler); - rpcServer.registerForwardMiddleware(() => { - return (input) => { - // This middleware needs to check the first message for the token - return input.pipeThrough( - new TransformStream({ - transform: (chunk, controller) => { - controller.enqueue(chunk); - }, - }), - ); - }; - }); + rpcServer.registerMiddleware( + abcUtils.authenticationMiddlewareServer(sessionManager, keyRing), + ); const rpcClient = await RPCClient.createRPCClient({ streamPairCreateCallback: async () => { const { clientPair, serverPair } = rpcTestUtils.createTapPairs(); @@ -96,15 +98,29 @@ describe('agentStatus', () => { }, logger, }); + rpcClient.registerMiddleware( + abcUtils.authenticationMiddlewareClient(session), + ); // Doing the test - const result = await agentUnlockCaller({}, rpcClient); - expect(result).toStrictEqual({ - pid: process.pid, - nodeId: keyRing.getNodeId(), - publicJwk: JSON.stringify( - keysUtils.publicKeyToJWK(keyRing.keyPair.publicKey), - ), + const result = await agentUnlockCaller( + { + Authorization: abcUtils.encodeAuthFromPassword(password), + }, + rpcClient, + ); + expect(result).toMatchObject({ + metadata: { + Authorization: expect.any(String), + }, + data: null, + }); + const result2 = await agentUnlockCaller({}, rpcClient); + expect(result2).toMatchObject({ + metadata: { + Authorization: expect.any(String), + }, + data: null, }); }); }); From a487fbb19e19eff08d4dcca788dfa78dda535f84 Mon Sep 17 00:00:00 2001 From: Brian Botha Date: Tue, 31 Jan 2023 12:22:39 +1100 Subject: [PATCH 26/44] tests: test fixes [ci skip] --- tests/RPC/RPC.test.ts | 10 +++------ tests/RPC/RPCClient.test.ts | 41 ++++++++++++++----------------------- 2 files changed, 18 insertions(+), 33 deletions(-) diff --git a/tests/RPC/RPC.test.ts b/tests/RPC/RPC.test.ts index 1f5220d55..3fe8b4cd4 100644 --- a/tests/RPC/RPC.test.ts +++ b/tests/RPC/RPC.test.ts @@ -44,10 +44,7 @@ describe('RPC', () => { logger, }); - const callerInterface = await rpcClient.duplexStreamCaller( - methodName, - {}, - ); + const callerInterface = await rpcClient.duplexStreamCaller(methodName); const writer = callerInterface.writable.getWriter(); const reader = callerInterface.readable.getReader(); for (const value of values) { @@ -93,7 +90,7 @@ describe('RPC', () => { const callerInterface = await rpcClient.serverStreamCaller< number, number - >(methodName, value, {}); + >(methodName, value); const outputs: Array = []; for await (const num of callerInterface) { @@ -137,7 +134,7 @@ describe('RPC', () => { const callerInterface = await rpcClient.clientStreamCaller< number, number - >(methodName, {}); + >(methodName); const writer = callerInterface.writable.getWriter(); for (const value of values) { await writer.write(value); @@ -176,7 +173,6 @@ describe('RPC', () => { const result = await rpcClient.unaryCaller( methodName, value, - {}, ); expect(result).toStrictEqual(value); await rpcServer.destroy(); diff --git a/tests/RPC/RPCClient.test.ts b/tests/RPC/RPCClient.test.ts index accc20ed5..f1d78a327 100644 --- a/tests/RPC/RPCClient.test.ts +++ b/tests/RPC/RPCClient.test.ts @@ -40,7 +40,7 @@ describe(`${RPCClient.name}`, () => { const callerInterface = await rpcClient.duplexStreamCaller< JSONValue, JSONValue - >(methodName, { hello: 'world' }); + >(methodName); const reader = callerInterface.readable.getReader(); const writer = callerInterface.writable.getWriter(); while (true) { @@ -84,7 +84,7 @@ describe(`${RPCClient.name}`, () => { const callerInterface = await rpcClient.serverStreamCaller< JSONValue, JSONValue - >(methodName, params as JSONValue, {}); + >(methodName, params as JSONValue); const values: Array = []; for await (const value of callerInterface) { values.push(value); @@ -120,7 +120,7 @@ describe(`${RPCClient.name}`, () => { const callerInterface = await rpcClient.clientStreamCaller< JSONValue, JSONValue - >(methodName, {}); + >(methodName); const writer = callerInterface.writable.getWriter(); for (const param of params) { await writer.write(param as JSONValue); @@ -158,7 +158,6 @@ describe(`${RPCClient.name}`, () => { const result = await rpcClient.unaryCaller( methodName, params as JSONValue, - {}, ); expect(result).toStrictEqual(message.result); expect((await outputResult)[0]?.toString()).toStrictEqual( @@ -196,7 +195,7 @@ describe(`${RPCClient.name}`, () => { const callerInterface = await rpcClient.duplexStreamCaller< JSONValue, JSONValue - >(methodName, { hello: 'world' }); + >(methodName); const consumeToError = async () => { for await (const _ of callerInterface.readable) { // No touch, just consume @@ -226,16 +225,12 @@ describe(`${RPCClient.name}`, () => { logger, }); let count = 0; - await rpcClient.withDuplexCaller( - methodName, - async function* (output) { - for await (const value of output) { - count += 1; - yield value; - } - }, - {}, - ); + await rpcClient.withDuplexCaller(methodName, async function* (output) { + for await (const value of output) { + count += 1; + yield value; + } + }); const result = await outputResult; // We're just checking that it consuming the messages as expected expect(result.length).toEqual(messages.length); @@ -262,14 +257,9 @@ describe(`${RPCClient.name}`, () => { logger, }); let count = 0; - await rpcClient.withServerCaller( - methodName, - params, - async (output) => { - for await (const _ of output) count += 1; - }, - {}, - ); + await rpcClient.withServerCaller(methodName, params, async (output) => { + for await (const _ of output) count += 1; + }); const result = await outputResult; expect(count).toEqual(messages.length); expect(result.toString()).toStrictEqual( @@ -308,7 +298,6 @@ describe(`${RPCClient.name}`, () => { yield inputMessage; } }, - {}, ); const expectedResult = inputMessages.map((v) => { return JSON.stringify({ @@ -360,7 +349,7 @@ describe(`${RPCClient.name}`, () => { const callerInterface = await rpcClient.duplexStreamCaller< JSONValue, JSONValue - >(methodName, { hello: 'world' }); + >(methodName); const reader = callerInterface.readable.getReader(); const writer = callerInterface.writable.getWriter(); while (true) { @@ -426,7 +415,7 @@ describe(`${RPCClient.name}`, () => { const callerInterface = await rpcClient.duplexStreamCaller< JSONValue, JSONValue - >(methodName, { hello: 'world' }); + >(methodName); const reader = callerInterface.readable.getReader(); const writer = callerInterface.writable.getWriter(); while (true) { From 05e6c6bfbbba85980138525b29c8d4c76ea3e7e8 Mon Sep 17 00:00:00 2001 From: Brian Botha Date: Tue, 31 Jan 2023 18:21:56 +1100 Subject: [PATCH 27/44] feat: client-client websocket communication - Related #500 - Related #501 - Related #502 [ci skip] --- package-lock.json | 46 ++++- package.json | 10 +- src/clientRPC/handlers/agentStatus.ts | 1 - src/clientRPC/utils.ts | 187 ++++++++++++++++++- tests/clientRPC/handlers/agentStatus.test.ts | 32 +++- tests/clientRPC/handlers/agentUnlock.test.ts | 30 ++- tests/clientRPC/websocket.test.ts | 92 +++++++++ 7 files changed, 376 insertions(+), 22 deletions(-) create mode 100644 tests/clientRPC/websocket.test.ts diff --git a/package-lock.json b/package-lock.json index eff34456b..c130defb4 100644 --- a/package-lock.json +++ b/package-lock.json @@ -26,6 +26,7 @@ "@peculiar/webcrypto": "^1.4.0", "@peculiar/x509": "^1.8.3", "@scure/bip39": "^1.1.0", + "@types/ws": "^8.5.4", "ajv": "^7.0.4", "bip39": "^3.0.3", "canonicalize": "^1.0.5", @@ -50,7 +51,8 @@ "threads": "^1.6.5", "tslib": "^2.4.0", "tsyringe": "^4.7.0", - "utp-native": "^2.5.3" + "utp-native": "^2.5.3", + "ws": "^8.12.0" }, "bin": { "pk": "dist/bin/polykey.js", @@ -3317,6 +3319,14 @@ "integrity": "sha512-Hl219/BT5fLAaz6NDkSuhzasy49dwQS/DSdu4MdggFB8zcXv7vflBI3xp7FEmkmdDkBUI2bPUNeMttp2knYdxw==", "dev": true }, + "node_modules/@types/ws": { + "version": "8.5.4", + "resolved": "https://registry.npmjs.org/@types/ws/-/ws-8.5.4.tgz", + "integrity": "sha512-zdQDHKUgcX/zBc4GrwsE/7dVdAD8JR4EuiAXiiUhhfyIJXXb2+PrGshFyeXWQPMmmZ2XxgaqclgpIC7eTXc1mg==", + "dependencies": { + "@types/node": "*" + } + }, "node_modules/@types/yargs": { "version": "17.0.10", "resolved": "https://registry.npmjs.org/@types/yargs/-/yargs-17.0.10.tgz", @@ -12023,6 +12033,26 @@ "node": "^12.13.0 || ^14.15.0 || >=16" } }, + "node_modules/ws": { + "version": "8.12.0", + "resolved": "https://registry.npmjs.org/ws/-/ws-8.12.0.tgz", + "integrity": "sha512-kU62emKIdKVeEIOIKVegvqpXMSTAMLJozpHZaJNDYqBjzlSYXQGviYwN1osDLJ9av68qHd4a2oSjd7yD4pacig==", + "engines": { + "node": ">=10.0.0" + }, + "peerDependencies": { + "bufferutil": "^4.0.1", + "utf-8-validate": ">=5.0.2" + }, + "peerDependenciesMeta": { + "bufferutil": { + "optional": true + }, + "utf-8-validate": { + "optional": true + } + } + }, "node_modules/xml": { "version": "1.0.1", "resolved": "https://registry.npmjs.org/xml/-/xml-1.0.1.tgz", @@ -14559,6 +14589,14 @@ "integrity": "sha512-Hl219/BT5fLAaz6NDkSuhzasy49dwQS/DSdu4MdggFB8zcXv7vflBI3xp7FEmkmdDkBUI2bPUNeMttp2knYdxw==", "dev": true }, + "@types/ws": { + "version": "8.5.4", + "resolved": "https://registry.npmjs.org/@types/ws/-/ws-8.5.4.tgz", + "integrity": "sha512-zdQDHKUgcX/zBc4GrwsE/7dVdAD8JR4EuiAXiiUhhfyIJXXb2+PrGshFyeXWQPMmmZ2XxgaqclgpIC7eTXc1mg==", + "requires": { + "@types/node": "*" + } + }, "@types/yargs": { "version": "17.0.10", "resolved": "https://registry.npmjs.org/@types/yargs/-/yargs-17.0.10.tgz", @@ -21017,6 +21055,12 @@ "signal-exit": "^3.0.7" } }, + "ws": { + "version": "8.12.0", + "resolved": "https://registry.npmjs.org/ws/-/ws-8.12.0.tgz", + "integrity": "sha512-kU62emKIdKVeEIOIKVegvqpXMSTAMLJozpHZaJNDYqBjzlSYXQGviYwN1osDLJ9av68qHd4a2oSjd7yD4pacig==", + "requires": {} + }, "xml": { "version": "1.0.1", "resolved": "https://registry.npmjs.org/xml/-/xml-1.0.1.tgz", diff --git a/package.json b/package.json index 7631c48c2..929c868c6 100644 --- a/package.json +++ b/package.json @@ -96,6 +96,7 @@ "@peculiar/webcrypto": "^1.4.0", "@peculiar/x509": "^1.8.3", "@scure/bip39": "^1.1.0", + "@types/ws": "^8.5.4", "ajv": "^7.0.4", "bip39": "^3.0.3", "canonicalize": "^1.0.5", @@ -118,13 +119,15 @@ "resource-counter": "^1.2.4", "sodium-native": "^3.4.1", "threads": "^1.6.5", - "utp-native": "^2.5.3", "tslib": "^2.4.0", - "tsyringe": "^4.7.0" + "tsyringe": "^4.7.0", + "utp-native": "^2.5.3", + "ws": "^8.12.0" }, "devDependencies": { "@babel/preset-env": "^7.13.10", "@fast-check/jest": "^1.1.0", + "@streamparser/json": "^0.0.12", "@swc/core": "^1.2.215", "@types/cross-spawn": "^6.0.2", "@types/google-protobuf": "^3.7.4", @@ -163,7 +166,6 @@ "ts-node": "^10.9.1", "tsconfig-paths": "^3.9.0", "typedoc": "^0.23.21", - "typescript": "^4.9.3", - "@streamparser/json": "^0.0.12" + "typescript": "^4.9.3" } } diff --git a/src/clientRPC/handlers/agentStatus.ts b/src/clientRPC/handlers/agentStatus.ts index 9e4855f0e..2d14ba558 100644 --- a/src/clientRPC/handlers/agentStatus.ts +++ b/src/clientRPC/handlers/agentStatus.ts @@ -37,7 +37,6 @@ const agentStatusCaller = async (metadata: POJO, rpcClient: RPCClient) => { const result = await rpcClient.unaryCaller( agentStatusName, null, - metadata, ); return { pid: result.pid, diff --git a/src/clientRPC/utils.ts b/src/clientRPC/utils.ts index 0aa45cd16..e72e14163 100644 --- a/src/clientRPC/utils.ts +++ b/src/clientRPC/utils.ts @@ -9,9 +9,17 @@ import type { JsonRpcResponse, MiddlewareFactory, } from '../RPC/types'; -import { TransformStream } from 'stream/web'; +import type { ReadableWritablePair } from 'stream/web'; +import type Logger from '@matrixai/logger'; +import type { ConnectionInfo, Host, Port } from '../network/types'; +import type RPCServer from '../RPC/RPCServer'; +import type { TLSSocket } from 'tls'; +import type { Server } from 'https'; +import { ReadableStream, TransformStream, WritableStream } from 'stream/web'; +import WebSocket, { WebSocketServer } from 'ws'; import * as clientErrors from '../client/errors'; import * as utils from '../utils'; +import { promise } from '../utils'; async function authenticate( sessionManager: SessionManager, @@ -156,10 +164,187 @@ function authenticationMiddlewareClient( }; } +function readableFromWebSocket( + ws: WebSocket, + logger: Logger, +): ReadableStream { + return new ReadableStream({ + start: (controller) => { + logger.info('starting'); + ws.on('message', (data) => { + logger.debug(`message: ${data.toString()}`); + ws.pause(); + const message = data as Buffer; + if (message.length === 0) { + logger.info('ENDING'); + ws.removeAllListeners('message'); + try { + controller.close(); + } catch { + // Ignore already closed + } + return; + } + controller.enqueue(message); + }); + ws.once('close', () => { + logger.info('closed'); + ws.removeAllListeners('message'); + try { + controller.close(); + } catch { + // Ignore already closed + } + }); + ws.once('error', (e) => { + controller.error(e); + }); + }, + cancel: () => { + logger.info('cancelled'); + ws.close(); + }, + pull: () => { + logger.debug('resuming'); + ws.resume(); + }, + }); +} + +function writeableFromWebSocket( + ws: WebSocket, + holdOpen: boolean, + logger: Logger, +): WritableStream { + return new WritableStream({ + start: (controller) => { + logger.info('starting'); + ws.once('error', (e) => { + logger.error(`error: ${e}`); + controller.error(e); + }); + ws.once('close', (code, reason) => { + logger.info( + `ws closing early! with code: ${code} and reason: ${reason.toString()}`, + ); + controller.error(Error('TMP WebSocket Closed early')); + }); + }, + close: () => { + logger.info('stream closing'); + ws.send(Buffer.from([])); + if (!holdOpen) ws.terminate(); + }, + abort: () => { + logger.info('aborting'); + ws.close(); + }, + write: async (chunk, controller) => { + logger.debug(`writing: ${chunk?.toString()}`); + const wait = promise(); + ws.send(chunk, (e) => { + if (e != null) { + logger.error(`error: ${e}`); + controller.error(e); + } + wait.resolveP(); + }); + await wait.p; + }, + }); +} + +function webSocketToWebStreamPair( + ws: WebSocket, + holdOpen: boolean, + logger: Logger, +): ReadableWritablePair { + return { + readable: readableFromWebSocket(ws, logger.getChild('readable')), + writable: writeableFromWebSocket(ws, holdOpen, logger.getChild('writable')), + }; +} + +function startConnection( + target: string, + logger: Logger, +): Promise> { + const ws = new WebSocket(target, { + // CheckServerIdentity: ( + // servername: string, + // cert: WebSocket.CertMeta, + // ): boolean => { + // console.log('CHECKING IDENTITY'); + // console.log(servername); + // console.log(cert); + // return false; + // }, + rejectUnauthorized: false, + // Ca: tlsConfig.certChainPem + }); + ws.once('close', () => logger.info('CLOSED')); + ws.once('upgrade', () => { + // Const tlsSocket = request.socket as TLSSocket; + // Console.log(tlsSocket.getPeerCertificate()); + logger.info('Test early cancellation'); + // Request.destroy(Error('some error')); + // tlsSocket.destroy(Error('some error')); + // ws.close(12345, 'some reason'); + // TODO: Use the existing verify method from the GRPC implementation + // TODO: Have this emit an error on verification failure. + // It's fine for the server side to close abruptly without error + }); + const prom = promise>(); + ws.once('open', () => { + logger.info('starting connection'); + prom.resolveP(webSocketToWebStreamPair(ws, true, logger)); + }); + return prom.p; +} + +function handleConnection(ws: WebSocket, logger: Logger): void { + ws.once('close', () => logger.info('CLOSED')); + const readable = readableFromWebSocket(ws, logger.getChild('readable')); + const writable = writeableFromWebSocket( + ws, + false, + logger.getChild('writable'), + ); + void readable.pipeTo(writable).catch((e) => logger.error(e)); +} + +function createClientServer( + server: Server, + rpcServer: RPCServer, + logger: Logger, +) { + logger.info('created server'); + const wss = new WebSocketServer({ + server, + }); + wss.on('error', (e) => logger.error(e)); + logger.info('created wss'); + wss.on('connection', (ws, req) => { + logger.info('connection!'); + const socket = req.socket as TLSSocket; + const streamPair = webSocketToWebStreamPair(ws, false, logger); + rpcServer.handleStream(streamPair, { + localHost: socket.localAddress! as Host, + localPort: socket.localPort! as Port, + remoteCertificates: socket.getPeerCertificate(), + remoteHost: socket.remoteAddress! as Host, + remotePort: socket.remotePort! as Port, + } as unknown as ConnectionInfo); + }); +} + export { authenticate, decodeAuth, encodeAuthFromPassword, authenticationMiddlewareServer, authenticationMiddlewareClient, + startConnection, + handleConnection, + createClientServer, }; diff --git a/tests/clientRPC/handlers/agentStatus.test.ts b/tests/clientRPC/handlers/agentStatus.test.ts index 6f0de9e39..be0f28556 100644 --- a/tests/clientRPC/handlers/agentStatus.test.ts +++ b/tests/clientRPC/handlers/agentStatus.test.ts @@ -1,7 +1,8 @@ -import type { ConnectionInfo } from '@/network/types'; +import type { Server } from 'https'; import fs from 'fs'; import path from 'path'; import os from 'os'; +import { createServer } from 'https'; import Logger, { LogLevel, StreamHandler } from '@matrixai/logger'; import { DB } from '@matrixai/db'; import KeyRing from '@/keys/KeyRing'; @@ -15,7 +16,8 @@ import { agentStatusCaller, } from '@/clientRPC/handlers/agentStatus'; import RPCClient from '@/RPC/RPCClient'; -import * as rpcTestUtils from '../../RPC/utils'; +import * as clientRPCUtils from '@/clientRPC/utils'; +import * as testsUtils from '../../utils'; describe('agentStatus', () => { const logger = new Logger('agentStatus test', LogLevel.WARN, [ @@ -27,6 +29,7 @@ describe('agentStatus', () => { let keyRing: KeyRing; let taskManager: TaskManager; let certManager: CertManager; + let server: Server; beforeEach(async () => { dataDir = await fs.promises.mkdtemp( @@ -53,8 +56,15 @@ describe('agentStatus', () => { taskManager, logger, }); + const tlsConfig = await testsUtils.createTLSConfig(keyRing.keyPair); + server = createServer({ + cert: tlsConfig.certChainPem, + key: tlsConfig.keyPrivatePem, + }); + server.listen(8080, '127.0.0.1'); }); afterEach(async () => { + server.close(); await certManager.stop(); await taskManager.stop(); await keyRing.stop(); @@ -70,18 +80,24 @@ describe('agentStatus', () => { container: { keyRing, certManager, - logger, + logger: logger.getChild('container'), }, - logger, + logger: logger.getChild('RPCServer'), }); rpcServer.registerUnaryHandler(agentStatusName, agentStatusHandler); + clientRPCUtils.createClientServer( + server, + rpcServer, + logger.getChild('server'), + ); const rpcClient = await RPCClient.createRPCClient({ streamPairCreateCallback: async () => { - const { clientPair, serverPair } = rpcTestUtils.createTapPairs(); - rpcServer.handleStream(serverPair, {} as ConnectionInfo); - return clientPair; + return clientRPCUtils.startConnection( + 'wss://localhost:8080', + logger.getChild('client'), + ); }, - logger, + logger: logger.getChild('RPCClient'), }); // Doing the test diff --git a/tests/clientRPC/handlers/agentUnlock.test.ts b/tests/clientRPC/handlers/agentUnlock.test.ts index 9dacadc70..e08f8e798 100644 --- a/tests/clientRPC/handlers/agentUnlock.test.ts +++ b/tests/clientRPC/handlers/agentUnlock.test.ts @@ -1,7 +1,8 @@ -import type { ConnectionInfo } from '@/network/types'; +import type { Server } from 'https'; import fs from 'fs'; import path from 'path'; import os from 'os'; +import { createServer } from 'https'; import Logger, { LogLevel, StreamHandler } from '@matrixai/logger'; import { DB } from '@matrixai/db'; import KeyRing from '@/keys/KeyRing'; @@ -17,10 +18,11 @@ import { import RPCClient from '@/RPC/RPCClient'; import { Session, SessionManager } from '@/sessions'; import * as abcUtils from '@/clientRPC/utils'; -import * as rpcTestUtils from '../../RPC/utils'; +import * as clientRPCUtils from '@/clientRPC/utils'; +import * as testsUtils from '../../utils'; -describe('agentStatus', () => { - const logger = new Logger('agentStatus test', LogLevel.WARN, [ +describe('agentUnlock', () => { + const logger = new Logger('agentUnlock test', LogLevel.INFO, [ new StreamHandler(), ]); const password = 'helloworld'; @@ -31,6 +33,7 @@ describe('agentStatus', () => { let certManager: CertManager; let session: Session; let sessionManager: SessionManager; + let server: Server; beforeEach(async () => { dataDir = await fs.promises.mkdtemp( @@ -67,8 +70,15 @@ describe('agentStatus', () => { keyRing, logger, }); + const tlsConfig = await testsUtils.createTLSConfig(keyRing.keyPair); + server = createServer({ + cert: tlsConfig.certChainPem, + key: tlsConfig.keyPrivatePem, + }); + server.listen(8080, '127.0.0.1'); }); afterEach(async () => { + server.close(); await certManager.stop(); await taskManager.stop(); await keyRing.stop(); @@ -90,11 +100,17 @@ describe('agentStatus', () => { rpcServer.registerMiddleware( abcUtils.authenticationMiddlewareServer(sessionManager, keyRing), ); + clientRPCUtils.createClientServer( + server, + rpcServer, + logger.getChild('server'), + ); const rpcClient = await RPCClient.createRPCClient({ streamPairCreateCallback: async () => { - const { clientPair, serverPair } = rpcTestUtils.createTapPairs(); - rpcServer.handleStream(serverPair, {} as ConnectionInfo); - return clientPair; + return clientRPCUtils.startConnection( + 'wss://localhost:8080', + logger.getChild('client'), + ); }, logger, }); diff --git a/tests/clientRPC/websocket.test.ts b/tests/clientRPC/websocket.test.ts new file mode 100644 index 000000000..345fe56cf --- /dev/null +++ b/tests/clientRPC/websocket.test.ts @@ -0,0 +1,92 @@ +import type { TLSConfig } from '@/network/types'; +import type { Server } from 'https'; +import fs from 'fs'; +import path from 'path'; +import os from 'os'; +import { createServer } from 'https'; +import Logger, { LogLevel, StreamHandler, formatting } from '@matrixai/logger'; +import RPCServer from '@/RPC/RPCServer'; +import RPCClient from '@/RPC/RPCClient'; +import { KeyRing } from '@/keys/index'; +import * as clientRPCUtils from '@/clientRPC/utils'; +import * as testsUtils from '../utils/index'; + +describe('websocket', () => { + const logger = new Logger('websocket test', LogLevel.WARN, [ + new StreamHandler( + formatting.format`${formatting.level}:${formatting.keys}:${formatting.msg}`, + ), + ]); + let dataDir: string; + let keyRing: KeyRing; + let tlsConfig: TLSConfig; + let server: Server; + + beforeEach(async () => { + dataDir = await fs.promises.mkdtemp( + path.join(os.tmpdir(), 'polykey-test-'), + ); + const keysPath = path.join(dataDir, 'keys'); + keyRing = await KeyRing.createKeyRing({ + keysPath: keysPath, + password: 'password', + logger: logger.getChild('keyRing'), + }); + tlsConfig = await testsUtils.createTLSConfig(keyRing.keyPair); + server = createServer({ + cert: tlsConfig.certChainPem, + key: tlsConfig.keyPrivatePem, + }); + server.listen(8080, '127.0.0.1'); + }); + afterEach(async () => { + server.close(); + await keyRing.stop(); + await fs.promises.rm(dataDir, { force: true, recursive: true }); + }); + + test('websocket should work with RPC', async () => { + // Setting up server + const rpcServer = new RPCServer({ + container: {}, + logger: logger.getChild('RPCServer'), + }); + rpcServer.registerUnaryHandler( + 'test1', + async (params, _container, _connectionInfo) => { + return params; + }, + ); + rpcServer.registerUnaryHandler('test2', async () => { + return { hello: 'not world' }; + }); + + clientRPCUtils.createClientServer( + server, + rpcServer, + logger.getChild('client'), + ); + + // Setting up client + const rpcClient = new RPCClient({ + logger: logger.getChild('RPCClient'), + streamPairCreateCallback: async () => { + return clientRPCUtils.startConnection( + 'wss://localhost:8080', + logger.getChild('Connection'), + ); + }, + }); + + // Making the call + await expect( + rpcClient.unaryCaller('test1', { hello: 'world2' }), + ).resolves.toStrictEqual({ hello: 'world2' }); + await expect( + rpcClient.unaryCaller('test2', { hello: 'world2' }), + ).resolves.toStrictEqual({ hello: 'not world' }); + await expect( + rpcClient.unaryCaller('test3', { hello: 'world2' }), + ).toReject(); + }); +}); From 54260208f501f1f5d8021d2f9ded4130593556aa Mon Sep 17 00:00:00 2001 From: Brian Botha Date: Thu, 2 Feb 2023 18:01:29 +1100 Subject: [PATCH 28/44] feat: raw handlers - Related #500 - Related #501 --- src/RPC/RPCClient.ts | 39 ++- src/RPC/RPCServer.ts | 293 ++++++++++--------- src/RPC/types.ts | 7 +- src/RPC/utils.ts | 67 ++++- src/clientRPC/utils.ts | 50 +++- tests/RPC/RPC.test.ts | 59 +++- tests/RPC/RPCClient.test.ts | 104 ++++++- tests/RPC/RPCServer.test.ts | 112 ++++--- tests/RPC/utils.test.ts | 22 ++ tests/RPC/utils.ts | 24 +- tests/clientRPC/handlers/agentStatus.test.ts | 17 +- tests/clientRPC/handlers/agentUnlock.test.ts | 13 +- tests/clientRPC/websocket.test.ts | 14 +- 13 files changed, 598 insertions(+), 223 deletions(-) diff --git a/src/RPC/RPCClient.ts b/src/RPC/RPCClient.ts index 0931a8020..fae7fdbcf 100644 --- a/src/RPC/RPCClient.ts +++ b/src/RPC/RPCClient.ts @@ -1,4 +1,4 @@ -import type { StreamPairCreateCallback } from './types'; +import type { JsonRpcRequestMessage, StreamPairCreateCallback } from './types'; import type { JSONValue } from 'types'; import type { ReadableWritablePair } from 'stream/web'; import type { @@ -49,6 +49,24 @@ class RPCClient { this.logger.info(`Destroyed ${this.constructor.name}`); } + @ready(new rpcErrors.ErrorRpcDestroyed()) + public async rawStreamCaller( + method: string, + params: JSONValue, + ): Promise> { + const streamPair = await this.streamPairCreateCallback(); + const tempWriter = streamPair.writable.getWriter(); + const header: JsonRpcRequestMessage = { + jsonrpc: '2.0', + method, + params, + id: null, + }; + await tempWriter.write(Buffer.from(JSON.stringify(header))); + tempWriter.releaseLock(); + return streamPair; + } + @ready(new rpcErrors.ErrorRpcDestroyed()) public async duplexStreamCaller( method: string, @@ -136,6 +154,25 @@ class RPCClient { return output.value; } + @ready(new rpcErrors.ErrorRpcDestroyed()) + public async withRawStreamCaller( + method: string, + params: JSONValue, + f: (output: AsyncGenerator) => AsyncGenerator, + ) { + const callerInterface = await this.rawStreamCaller(method, params); + const outputGenerator = async function* () { + for await (const value of callerInterface.readable) { + yield value; + } + }; + const writer = callerInterface.writable.getWriter(); + for await (const value of f(outputGenerator())) { + await writer.write(value); + } + await writer.close(); + } + @ready(new rpcErrors.ErrorRpcDestroyed()) public async withDuplexCaller( method: string, diff --git a/src/RPC/RPCServer.ts b/src/RPC/RPCServer.ts index e6166fba8..c11a45c3a 100644 --- a/src/RPC/RPCServer.ts +++ b/src/RPC/RPCServer.ts @@ -6,6 +6,7 @@ import type { JsonRpcResponse, JsonRpcResponseError, JsonRpcResponseResult, + RawDuplexStreamHandler, ServerStreamHandler, UnaryHandler, } from './types'; @@ -44,8 +45,7 @@ class RPCServer { // Properties protected container: POJO; protected logger: Logger; - protected handlerMap: Map> = - new Map(); + protected handlerMap: Map = new Map(); protected activeStreams: Set> = new Set(); protected events: EventTarget = new EventTarget(); @@ -72,12 +72,118 @@ class RPCServer { this.logger.info(`Destroyed ${this.constructor.name}`); } + @ready(new rpcErrors.ErrorRpcDestroyed()) + public registerRawStreamHandler( + method: string, + handler: RawDuplexStreamHandler, + ) { + this.handlerMap.set(method, handler); + } + @ready(new rpcErrors.ErrorRpcDestroyed()) public registerDuplexStreamHandler( method: string, handler: DuplexStreamHandler, ) { - this.handlerMap.set(method, handler); + // This needs to handle all the message parsing and conversion from + // generators to the raw streams. + + const rawSteamHandler: RawDuplexStreamHandler = ( + [input, header], + container, + connectionInfo, + ctx, + ) => { + // Middleware + const outputTransformStream = new rpcUtils.JsonMessageToJsonStream(); + const outputReadableSteam = outputTransformStream.readable; + let forwardStream = input.pipeThrough( + new rpcUtils.JsonToJsonMessageStream( + rpcUtils.parseJsonRpcRequest, + undefined, + header, + ), + ); + let reverseStream = outputTransformStream.writable; + for (const middlewareFactory of this.middleware) { + const middleware = middlewareFactory(); + forwardStream = forwardStream.pipeThrough(middleware.forward); + void middleware.reverse.readable.pipeTo(reverseStream).catch(() => {}); + reverseStream = middleware.reverse.writable; + } + const events = this.events; + const outputGen = async function* (): AsyncGenerator< + JsonRpcResponse + > { + if (ctx.signal.aborted) throw ctx.signal.reason; + const dataGen = async function* () { + for await (const data of forwardStream) { + yield data.params as I; + } + }; + for await (const response of handler( + dataGen(), + container, + connectionInfo, + ctx, + )) { + const responseMessage: JsonRpcResponseResult = { + jsonrpc: '2.0', + result: response, + id: null, + }; + yield responseMessage; + } + }; + const outputGenerator = outputGen(); + const reverseMiddlewareStream = new ReadableStream< + JsonRpcResponse + >({ + pull: async (controller) => { + try { + const { value, done } = await outputGenerator.next(); + if (done) { + controller.close(); + return; + } + controller.enqueue(value); + } catch (e) { + if (rpcUtils.isReturnableError(e)) { + // We want to convert this error to an error message and pass it along + const rpcError: JsonRpcError = { + code: e.exitCode ?? sysexits.UNKNOWN, + message: e.description ?? '', + data: rpcUtils.fromError(e), + }; + const rpcErrorMessage: JsonRpcResponseError = { + jsonrpc: '2.0', + error: rpcError, + id: null, + }; + controller.enqueue(rpcErrorMessage); + } else { + // These errors are emitted to the event system + events.dispatchEvent( + new rpcUtils.RPCErrorEvent({ + detail: { + error: e, + }, + }), + ); + } + controller.close(); + } + }, + cancel: async (reason) => { + await outputGenerator.throw(reason); + }, + }); + void reverseMiddlewareStream.pipeTo(reverseStream).catch(() => {}); + + return outputReadableSteam; + }; + + this.registerRawStreamHandler(method, rawSteamHandler); } @ready(new rpcErrors.ErrorRpcDestroyed()) @@ -96,7 +202,7 @@ class RPCServer { break; } }; - this.handlerMap.set(method, wrapperDuplex); + this.registerDuplexStreamHandler(method, wrapperDuplex); } @ready(new rpcErrors.ErrorRpcDestroyed()) @@ -115,7 +221,7 @@ class RPCServer { break; } }; - this.handlerMap.set(method, wrapperDuplex); + this.registerDuplexStreamHandler(method, wrapperDuplex); } @ready(new rpcErrors.ErrorRpcDestroyed()) @@ -131,7 +237,7 @@ class RPCServer { ) { yield handler(input, container, connectionInfo, ctx); }; - this.handlerMap.set(method, wrapperDuplex); + this.registerDuplexStreamHandler(method, wrapperDuplex); } @ready(new rpcErrors.ErrorRpcDestroyed()) @@ -142,143 +248,60 @@ class RPCServer { // This will take a buffer stream of json messages and set up service // handling for it. // Constructing the PromiseCancellable for tracking the active stream - let resolve: (value: void | PromiseLike) => void; - const abortController = new AbortController(); const handlerProm: PromiseCancellable = new PromiseCancellable( - (resolve_) => { - resolve = resolve_; + (resolve, reject, signal) => { + const prom = (async () => { + const { firstMessageProm, headTransformStream } = + rpcUtils.extractFirstMessageTransform(rpcUtils.parseJsonRpcRequest); + const inputStreamEndProm = streamPair.readable.pipeTo( + headTransformStream.writable, + ); + const inputStream = headTransformStream.readable; + // Read a single empty value to consume the first message + const reader = inputStream.getReader(); + await reader.read(); + reader.releaseLock(); + const leadingMetadataMessage = await firstMessageProm; + // If the stream ends early then we just stop processing + if (leadingMetadataMessage == null) { + await inputStream.cancel(); + await streamPair.writable.close(); + await inputStreamEndProm; + return; + } + const method = leadingMetadataMessage.method; + const handler = this.handlerMap.get(method); + if (handler == null) { + await inputStream.cancel(); + await streamPair.writable.close(); + await inputStreamEndProm; + return; + } + if (signal.aborted) { + await inputStream.cancel(); + await streamPair.writable.close(); + await inputStreamEndProm; + return; + } + const outputStream = handler( + [inputStream, leadingMetadataMessage], + this.container, + connectionInfo, + { signal }, + ); + await Promise.allSettled([ + inputStreamEndProm, + outputStream.pipeTo(streamPair.writable), + ]); + })(); + prom.then(resolve, reject); }, - abortController, ); // Putting the PromiseCancellable into the active streams map this.activeStreams.add(handlerProm); void handlerProm .finally(() => this.activeStreams.delete(handlerProm)) .catch(() => {}); - // Setting up middleware - let forwardStream = streamPair.readable.pipeThrough( - new rpcUtils.JsonToJsonMessageStream(rpcUtils.parseJsonRpcRequest), - ); - const outputTransformStream = new rpcUtils.JsonMessageToJsonStream(); - void outputTransformStream.readable - .pipeTo(streamPair.writable) - .catch(() => {}); - let reverseStream = outputTransformStream.writable; - for (const middlewareFactory of this.middleware) { - const middleware = middlewareFactory(); - forwardStream = forwardStream.pipeThrough(middleware.forward); - void middleware.reverse.readable.pipeTo(reverseStream).catch(() => {}); - reverseStream = middleware.reverse.writable; - } - // While ReadableStream can be converted to AsyncIterable, we want it as - // a generator. - const inputGen = async function* () { - for await (const dataMessage of forwardStream) { - yield dataMessage; - } - }; - const container = this.container; - const handlerMap = this.handlerMap; - const ctx = { signal: abortController.signal }; - const events = this.events; - const outputGen = async function* (): AsyncGenerator< - JsonRpcResponse - > { - // Step 1, authentication and establishment - // read the first message, lets assume the first message is always leading - // metadata. - const input = inputGen(); - if (ctx.signal.aborted) throw ctx.signal.reason; - const leadingMetadataMessage = await input.next(); - // If the stream ends early then we just stop processing - if (leadingMetadataMessage.done === true) return; - const method = leadingMetadataMessage.value.method; - const initialParams = leadingMetadataMessage.value.params; - const dataGen = async function* () { - yield initialParams as JSONValue; - for await (const data of input) { - yield data.params as JSONValue; - } - }; - const handler = handlerMap.get(method); - if (handler == null) { - // Failed to find handler, this is an error. We should respond with - // an error message. - throw new rpcErrors.ErrorRpcHandlerMissing( - `No handler registered for method: ${method}`, - ); - } - if (ctx.signal.aborted) throw ctx.signal.reason; - for await (const response of handler( - dataGen(), - container, - connectionInfo, - ctx, - )) { - const responseMessage: JsonRpcResponseResult = { - jsonrpc: '2.0', - result: response, - id: null, - }; - yield responseMessage; - } - }; - - const outputGenerator = outputGen(); - - const reverseMiddlewareStream = new ReadableStream< - JsonRpcResponse - >({ - pull: async (controller) => { - try { - const { value, done } = await outputGenerator.next(); - if (done) { - try { - controller.close(); - } catch { - // Ignore already closed error - } - resolve(); - return; - } - controller.enqueue(value); - } catch (e) { - if (rpcUtils.isReturnableError(e)) { - // We want to convert this error to an error message and pass it along - const rpcError: JsonRpcError = { - code: e.exitCode ?? sysexits.UNKNOWN, - message: e.description ?? '', - data: rpcUtils.fromError(e), - }; - const rpcErrorMessage: JsonRpcResponseError = { - jsonrpc: '2.0', - error: rpcError, - id: null, - }; - controller.enqueue(rpcErrorMessage); - } else { - // These errors are emitted to the event system - events.dispatchEvent( - new rpcUtils.RPCErrorEvent({ - detail: { - error: e, - }, - }), - ); - } - try { - controller.close(); - } catch { - // Ignore already closed error - } - resolve(); - } - }, - cancel: async (reason) => { - await outputGenerator.throw(reason); - }, - }); - void reverseMiddlewareStream.pipeTo(reverseStream).catch(() => {}); } @ready(new rpcErrors.ErrorRpcDestroyed()) diff --git a/src/RPC/types.ts b/src/RPC/types.ts index b1554469d..06a7f15fd 100644 --- a/src/RPC/types.ts +++ b/src/RPC/types.ts @@ -1,7 +1,7 @@ import type { JSONValue, POJO } from '../types'; import type { ConnectionInfo } from '../network/types'; import type { ContextCancellable } from '../contexts/types'; -import type { ReadableWritablePair } from 'stream/web'; +import type { ReadableStream, ReadableWritablePair } from 'stream/web'; /** * This is the JSON RPC request object. this is the generic message type used for the RPC. @@ -106,6 +106,10 @@ type Handler = ( connectionInfo: ConnectionInfo, ctx: ContextCancellable, ) => O; +type RawDuplexStreamHandler = Handler< + [ReadableStream, JsonRpcRequest], + ReadableStream +>; type DuplexStreamHandler = Handler< AsyncGenerator, AsyncGenerator @@ -141,6 +145,7 @@ export type { JsonRpcRequest, JsonRpcResponse, JsonRpcMessage, + RawDuplexStreamHandler, DuplexStreamHandler, ServerStreamHandler, ClientStreamHandler, diff --git a/src/RPC/utils.ts b/src/RPC/utils.ts index 9c87232c7..bacd30224 100644 --- a/src/RPC/utils.ts +++ b/src/RPC/utils.ts @@ -33,6 +33,7 @@ class JsonToJsonMessage constructor( protected messageParser: (message: unknown) => T, protected byteLimit: number, + protected firstMessage: T | undefined, ) {} protected parser = new jsonStreamParsers.JSONParser({ @@ -41,6 +42,7 @@ class JsonToJsonMessage }); start: TransformerStartCallback = async (controller) => { + if (this.firstMessage != null) controller.enqueue(this.firstMessage); this.parser.onValue = (value) => { const jsonMessage = this.messageParser(value.value); controller.enqueue(jsonMessage); @@ -69,8 +71,9 @@ class JsonToJsonMessageStream extends TransformStream< constructor( messageParser: (message: unknown) => T, byteLimit: number = 1024 * 1024, + firstMessage?: T, ) { - super(new JsonToJsonMessage(messageParser, byteLimit)); + super(new JsonToJsonMessage(messageParser, byteLimit, firstMessage)); } } @@ -564,6 +567,67 @@ class QueueMergingTransformStream extends TransformStream { } } +function extractFirstMessageTransform( + messageParser: (message: unknown) => T, + byteLimit: number = 1024 * 1024, +) { + const parser = new jsonStreamParsers.JSONParser({ + separator: '', + paths: ['$'], + }); + const messageProm = promise(); + let bytesWritten = 0; + let lastChunk: Uint8Array | null = null; + let passThrough = false; + const headTransformStream = new TransformStream({ + start: (controller) => { + parser.onValue = (value) => { + let jsonMessage: T; + try { + jsonMessage = messageParser(value.value); + } catch (e) { + const error = new rpcErrors.ErrorRpcParse(undefined, { cause: e }); + messageProm.rejectP(error); + controller.error(error); + return; + } + messageProm.resolveP(jsonMessage); + const firstMessageBuffer = Buffer.from(JSON.stringify(jsonMessage)); + const difference = bytesWritten - firstMessageBuffer.length; + // Write empty value for the first read that initializes the stream + controller.enqueue(new Uint8Array()); + if (difference > 0) { + controller.enqueue( + lastChunk?.slice(lastChunk?.byteLength - difference), + ); + } + parser.end(); + passThrough = true; + }; + }, + transform: (chunk, controller) => { + if (passThrough) { + controller.enqueue(chunk); + return; + } + try { + bytesWritten += chunk.byteLength; + lastChunk = chunk; + parser.write(chunk); + } catch (e) { + // Ignore error + } + if (bytesWritten > byteLimit) { + messageProm.rejectP(new rpcErrors.ErrorRpcMessageLength()); + } + }, + flush: () => { + messageProm.resolveP(undefined); + }, + }); + return { headTransformStream, firstMessageProm: messageProm.p }; +} + export { JsonToJsonMessageStream, JsonMessageToJsonStream, @@ -582,4 +646,5 @@ export { RPCErrorEvent, controllerTransformationFactory, QueueMergingTransformStream, + extractFirstMessageTransform, }; diff --git a/src/clientRPC/utils.ts b/src/clientRPC/utils.ts index e72e14163..631825de3 100644 --- a/src/clientRPC/utils.ts +++ b/src/clientRPC/utils.ts @@ -15,6 +15,8 @@ import type { ConnectionInfo, Host, Port } from '../network/types'; import type RPCServer from '../RPC/RPCServer'; import type { TLSSocket } from 'tls'; import type { Server } from 'https'; +import type net from 'net'; +import type https from 'https'; import { ReadableStream, TransformStream, WritableStream } from 'stream/web'; import WebSocket, { WebSocketServer } from 'ws'; import * as clientErrors from '../client/errors'; @@ -171,7 +173,7 @@ function readableFromWebSocket( return new ReadableStream({ start: (controller) => { logger.info('starting'); - ws.on('message', (data) => { + const messageHandler = (data) => { logger.debug(`message: ${data.toString()}`); ws.pause(); const message = data as Buffer; @@ -186,10 +188,11 @@ function readableFromWebSocket( return; } controller.enqueue(message); - }); + }; + ws.on('message', messageHandler); ws.once('close', () => { logger.info('closed'); - ws.removeAllListeners('message'); + ws.removeListener('message', messageHandler); try { controller.close(); } catch { @@ -266,10 +269,11 @@ function webSocketToWebStreamPair( } function startConnection( - target: string, + host: string, + port: number, logger: Logger, ): Promise> { - const ws = new WebSocket(target, { + const ws = new WebSocket(`wss://${host}:${port}`, { // CheckServerIdentity: ( // servername: string, // cert: WebSocket.CertMeta, @@ -283,17 +287,17 @@ function startConnection( // Ca: tlsConfig.certChainPem }); ws.once('close', () => logger.info('CLOSED')); - ws.once('upgrade', () => { - // Const tlsSocket = request.socket as TLSSocket; - // Console.log(tlsSocket.getPeerCertificate()); - logger.info('Test early cancellation'); - // Request.destroy(Error('some error')); - // tlsSocket.destroy(Error('some error')); - // ws.close(12345, 'some reason'); - // TODO: Use the existing verify method from the GRPC implementation - // TODO: Have this emit an error on verification failure. - // It's fine for the server side to close abruptly without error - }); + // Ws.once('upgrade', () => { + // // Const tlsSocket = request.socket as TLSSocket; + // // Console.log(tlsSocket.getPeerCertificate()); + // logger.info('Test early cancellation'); + // // Request.destroy(Error('some error')); + // // tlsSocket.destroy(Error('some error')); + // // ws.close(12345, 'some reason'); + // // TODO: Use the existing verify method from the GRPC implementation + // // TODO: Have this emit an error on verification failure. + // // It's fine for the server side to close abruptly without error + // }); const prom = promise>(); ws.once('open', () => { logger.info('starting connection'); @@ -336,6 +340,19 @@ function createClientServer( remotePort: socket.remotePort! as Port, } as unknown as ConnectionInfo); }); + wss.once('close', () => { + wss.removeAllListeners('error'); + wss.removeAllListeners('connection'); + }); + return wss; +} + +async function listen(server: https.Server, host?: string, port?: number) { + await new Promise((resolve) => { + server.listen(port, host ?? '127.0.0.1', undefined, () => resolve()); + }); + const addressInfo = server.address() as net.AddressInfo; + return addressInfo.port; } export { @@ -347,4 +364,5 @@ export { startConnection, handleConnection, createClientServer, + listen, }; diff --git a/tests/RPC/RPC.test.ts b/tests/RPC/RPC.test.ts index 3fe8b4cd4..33572944a 100644 --- a/tests/RPC/RPC.test.ts +++ b/tests/RPC/RPC.test.ts @@ -1,6 +1,8 @@ import type { ClientStreamHandler, DuplexStreamHandler, + JsonRpcRequest, + RawDuplexStreamHandler, ServerStreamHandler, UnaryHandler, } from '@/RPC/types'; @@ -17,6 +19,60 @@ describe('RPC', () => { const methodName = 'testMethod'; + testProp( + 'RPC communication with raw stream', + [rpcTestUtils.rawDataArb], + async (inputData) => { + const [outputResult, outputWriterStream] = + rpcTestUtils.streamToArray(); + const { clientPair, serverPair } = rpcTestUtils.createTapPairs< + Uint8Array, + Uint8Array + >(); + + const container = {}; + const rpcServer = await RPCServer.createRPCServer({ container, logger }); + let header: JsonRpcRequest | undefined; + const rawHandler: RawDuplexStreamHandler = ( + [input, header_], + _container, + _connectionInfo, + _ctx, + ) => { + header = header_; + return input; + }; + + rpcServer.registerRawStreamHandler(methodName, rawHandler); + rpcServer.handleStream(serverPair, {} as ConnectionInfo); + + const rpcClient = await RPCClient.createRPCClient({ + streamPairCreateCallback: async () => clientPair, + logger, + }); + + const callerInterface = await rpcClient.rawStreamCaller(methodName, { + hello: 'world', + }); + const writer = callerInterface.writable.getWriter(); + const pipeProm = callerInterface.readable.pipeTo(outputWriterStream); + for (const value of inputData) { + await writer.write(value); + } + await writer.close(); + const expectedHeader: JsonRpcRequest = { + jsonrpc: '2.0', + method: methodName, + params: { hello: 'world' }, + id: null, + }; + expect(header).toStrictEqual(expectedHeader); + expect(await outputResult).toStrictEqual(inputData); + await pipeProm; + await rpcServer.destroy(); + await rpcClient.destroy(); + }, + ); testProp( 'RPC communication with duplex stream', [fc.array(rpcTestUtils.safeJsonValueArb, { minLength: 1 })], @@ -59,7 +115,6 @@ describe('RPC', () => { await rpcClient.destroy(); }, ); - testProp( 'RPC communication with server stream', [fc.integer({ min: 1, max: 100 })], @@ -101,7 +156,6 @@ describe('RPC', () => { await rpcClient.destroy(); }, ); - testProp( 'RPC communication with client stream', [fc.array(fc.integer(), { minLength: 1 }).noShrink()], @@ -147,7 +201,6 @@ describe('RPC', () => { await rpcClient.destroy(); }, ); - testProp( 'RPC communication with unary call', [rpcTestUtils.safeJsonValueArb], diff --git a/tests/RPC/RPCClient.test.ts b/tests/RPC/RPCClient.test.ts index f1d78a327..18984e06f 100644 --- a/tests/RPC/RPCClient.test.ts +++ b/tests/RPC/RPCClient.test.ts @@ -5,7 +5,7 @@ import type { JsonRpcRequestMessage, JsonRpcResponse, } from '@/RPC/types'; -import { TransformStream } from 'stream/web'; +import { TransformStream, ReadableStream } from 'stream/web'; import Logger, { LogLevel, StreamHandler } from '@matrixai/logger'; import { testProp, fc } from '@fast-check/jest'; import RPCClient from '@/RPC/RPCClient'; @@ -25,6 +25,57 @@ describe(`${RPCClient.name}`, () => { }) .noShrink(); + testProp( + 'raw duplex caller', + [ + rpcTestUtils.safeJsonValueArb, + rpcTestUtils.rawDataArb, + rpcTestUtils.rawDataArb, + ], + async (headerParams, inputData, outputData) => { + const [inputResult, inputWritableStream] = + rpcTestUtils.streamToArray(); + const [outputResult, outputWritableStream] = + rpcTestUtils.streamToArray(); + const streamPair: ReadableWritablePair = { + readable: new ReadableStream({ + start: (controller) => { + for (const datum of outputData) { + controller.enqueue(datum); + } + controller.close(); + }, + }), + writable: inputWritableStream, + }; + const rpcClient = await RPCClient.createRPCClient({ + streamPairCreateCallback: async () => streamPair, + logger, + }); + const callerInterface = await rpcClient.rawStreamCaller( + 'testMethod', + headerParams, + ); + await callerInterface.readable.pipeTo(outputWritableStream); + const writer = callerInterface.writable.getWriter(); + for (const inputDatum of inputData) { + await writer.write(inputDatum); + } + await writer.close(); + + const expectedHeader: JsonRpcRequest = { + jsonrpc: '2.0', + method: methodName, + params: headerParams, + id: null, + }; + expect(await inputResult).toStrictEqual([ + Buffer.from(JSON.stringify(expectedHeader)), + ...inputData, + ]); + expect(await outputResult).toStrictEqual(outputData); + }, + ); testProp('generic duplex caller', [specificMessageArb], async (messages) => { const inputStream = rpcTestUtils.jsonRpcStream(messages); const [outputResult, outputStream] = @@ -209,6 +260,57 @@ describe(`${RPCClient.name}`, () => { await rpcClient.destroy(); }, ); + testProp( + 'withRawStreamCaller', + [ + rpcTestUtils.safeJsonValueArb, + rpcTestUtils.rawDataArb, + rpcTestUtils.rawDataArb, + ], + async (headerParams, inputData, outputData) => { + const [inputResult, inputWritableStream] = + rpcTestUtils.streamToArray(); + const streamPair: ReadableWritablePair = { + readable: new ReadableStream({ + start: (controller) => { + for (const datum of outputData) { + controller.enqueue(datum); + } + controller.close(); + }, + }), + writable: inputWritableStream, + }; + const rpcClient = await RPCClient.createRPCClient({ + streamPairCreateCallback: async () => streamPair, + logger, + }); + const outputResult: Array = []; + await rpcClient.withRawStreamCaller( + methodName, + headerParams, + async function* (output) { + for await (const outputValue of output) { + outputResult.push(outputValue); + } + for (const inputDatum of inputData) { + yield inputDatum; + } + }, + ); + const expectedHeader: JsonRpcRequest = { + jsonrpc: '2.0', + method: methodName, + params: headerParams, + id: null, + }; + expect(await inputResult).toStrictEqual([ + Buffer.from(JSON.stringify(expectedHeader)), + ...inputData, + ]); + expect(outputResult).toStrictEqual(outputData); + }, + ); testProp( 'withDuplexCaller', [fc.array(rpcTestUtils.jsonRpcResponseResultArb(), { minLength: 1 })], diff --git a/tests/RPC/RPCServer.test.ts b/tests/RPC/RPCServer.test.ts index 0af80fb19..8c772f1cc 100644 --- a/tests/RPC/RPCServer.test.ts +++ b/tests/RPC/RPCServer.test.ts @@ -5,6 +5,7 @@ import type { JsonRpcRequest, JsonRpcResponse, JsonRpcResponseError, + RawDuplexStreamHandler, ServerStreamHandler, UnaryHandler, } from '@/RPC/types'; @@ -12,7 +13,7 @@ import type { JSONValue } from '@/types'; import type { ConnectionInfo, Host, Port } from '@/network/types'; import type { NodeId } from '@/ids'; import type { ReadableWritablePair } from 'stream/web'; -import { TransformStream } from 'stream/web'; +import { TransformStream, ReadableStream } from 'stream/web'; import { fc, testProp } from '@fast-check/jest'; import Logger, { LogLevel, StreamHandler } from '@matrixai/logger'; import RPCServer from '@/RPC/RPCServer'; @@ -23,14 +24,83 @@ describe(`${RPCServer.name}`, () => { const logger = new Logger(`${RPCServer.name} Test`, LogLevel.WARN, [ new StreamHandler(), ]); - const methodName = 'testMethod'; const specificMessageArb = fc .array(rpcTestUtils.jsonRpcRequestMessageArb(fc.constant(methodName)), { minLength: 5, }) .noShrink(); + const singleNumberMessageArb = fc.array( + rpcTestUtils.jsonRpcRequestMessageArb( + fc.constant(methodName), + fc.integer({ min: 1, max: 20 }), + ), + { + minLength: 2, + maxLength: 10, + }, + ); + const errorArb = fc.oneof( + fc.constant(new rpcErrors.ErrorRpcParse()), + fc.constant(new rpcErrors.ErrorRpcHandlerMissing()), + fc.constant(new rpcErrors.ErrorRpcProtocal()), + fc.constant(new rpcErrors.ErrorRpcMessageLength()), + fc.constant(new rpcErrors.ErrorRpcRemoteError()), + ); + const validToken = 'VALIDTOKEN'; + const invalidTokenMessageArb = rpcTestUtils.jsonRpcRequestMessageArb( + fc.constant('testMethod'), + fc.record({ + metadata: fc.record({ + token: fc.string().filter((v) => v !== validToken), + }), + data: rpcTestUtils.safeJsonValueArb, + }), + ); + testProp( + 'can stream data with raw duplex stream handler', + [specificMessageArb], + async (messages) => { + const stream = rpcTestUtils + .jsonRpcStream(messages) + .pipeThrough( + new rpcTestUtils.BufferStreamToSnippedStream([4, 7, 13, 2, 6]), + ); + const container = {}; + const rpcServer = await RPCServer.createRPCServer({ container, logger }); + const [outputResult, outputStream] = rpcTestUtils.streamToArray(); + const readWriteStream: ReadableWritablePair = { + readable: stream, + writable: outputStream, + }; + + const rawDuplexHandler: RawDuplexStreamHandler = ( + [input], + _container, + _connectionInfo, + _ctx, + ) => { + void (async () => { + for await (const _ of input) { + // No touch, only consume + } + })().catch(() => {}); + return new ReadableStream({ + start: (controller) => { + controller.enqueue(Buffer.from('hello world!')); + controller.close(); + }, + }); + }; + + rpcServer.registerRawStreamHandler(methodName, rawDuplexHandler); + rpcServer.handleStream(readWriteStream, {} as ConnectionInfo); + await outputResult; + await rpcServer.destroy(); + }, + { numRuns: 1 }, + ); testProp( 'can stream data with duplex stream handler', [specificMessageArb], @@ -58,7 +128,6 @@ describe(`${RPCServer.name}`, () => { await rpcServer.destroy(); }, ); - testProp( 'can stream data with client stream handler', [specificMessageArb], @@ -87,18 +156,6 @@ describe(`${RPCServer.name}`, () => { await rpcServer.destroy(); }, ); - - const singleNumberMessageArb = fc.array( - rpcTestUtils.jsonRpcRequestMessageArb( - fc.constant(methodName), - fc.integer({ min: 1, max: 20 }), - ), - { - minLength: 2, - maxLength: 10, - }, - ); - testProp( 'can stream data with server stream handler', [singleNumberMessageArb], @@ -125,7 +182,6 @@ describe(`${RPCServer.name}`, () => { await rpcServer.destroy(); }, ); - testProp( 'can stream data with server stream handler', [specificMessageArb], @@ -154,7 +210,6 @@ describe(`${RPCServer.name}`, () => { await rpcServer.destroy(); }, ); - testProp( 'Handler is provided with container', [specificMessageArb], @@ -186,7 +241,6 @@ describe(`${RPCServer.name}`, () => { await rpcServer.destroy(); }, ); - testProp( 'Handler is provided with connectionInfo', [specificMessageArb], @@ -223,7 +277,6 @@ describe(`${RPCServer.name}`, () => { expect(handledConnectionInfo).toBe(connectionInfo); }, ); - // Problem with the tap stream. It seems to block the whole stream. // If I don't pipe the tap to the output we actually iterate over some data. testProp.skip( @@ -271,7 +324,6 @@ describe(`${RPCServer.name}`, () => { await rpcServer.destroy(); }, ); - testProp('Handler yields nothing', [specificMessageArb], async (messages) => { const stream = rpcTestUtils.jsonRpcStream(messages); const container = {}; @@ -295,14 +347,6 @@ describe(`${RPCServer.name}`, () => { // We're just expecting no errors await rpcServer.destroy(); }); - - const errorArb = fc.oneof( - fc.constant(new rpcErrors.ErrorRpcParse()), - fc.constant(new rpcErrors.ErrorRpcHandlerMissing()), - fc.constant(new rpcErrors.ErrorRpcProtocal()), - fc.constant(new rpcErrors.ErrorRpcMessageLength()), - fc.constant(new rpcErrors.ErrorRpcRemoteError()), - ); testProp( 'should send error message', [specificMessageArb, errorArb], @@ -458,17 +502,6 @@ describe(`${RPCServer.name}`, () => { ); await rpcServer.destroy(); }); - const validToken = 'VALIDTOKEN'; - const invalidTokenMessageArb = rpcTestUtils.jsonRpcRequestMessageArb( - undefined, - fc.record({ - metadata: fc.record({ - token: fc.string().filter((v) => v !== validToken), - }), - data: rpcTestUtils.safeJsonValueArb, - }), - ); - testProp( 'forward middleware authentication', [invalidTokenMessageArb], @@ -544,7 +577,6 @@ describe(`${RPCServer.name}`, () => { await rpcServer.destroy(); }, ); - // TODO: // - Test odd conditions for handlers, like extra messages where 1 is expected. // - Expectations can't be inside the handlers otherwise they're caught. diff --git a/tests/RPC/utils.test.ts b/tests/RPC/utils.test.ts index 4f8f12206..63a1bdfec 100644 --- a/tests/RPC/utils.test.ts +++ b/tests/RPC/utils.test.ts @@ -136,6 +136,28 @@ describe('utils tests', () => { { numRuns: 1000 }, ); + testProp( + 'can get the head message', + [rpcTestUtils.jsonMessagesArb], + async (messages) => { + const { firstMessageProm, headTransformStream } = + rpcUtils.extractFirstMessageTransform(rpcUtils.parseJsonRpcRequest); + const parsedStream = rpcTestUtils + .jsonRpcStream(messages) + .pipeThrough(new rpcTestUtils.BufferStreamToSnippedStream([7])) + .pipeThrough(headTransformStream) + .pipeThrough( + new rpcUtils.JsonToJsonMessageStream(rpcUtils.parseJsonRpcMessage), + ); // Converting back. + + expect(await firstMessageProm).toStrictEqual(messages[0]); + expect(await AsyncIterable.as(parsedStream).toArray()).toStrictEqual( + messages.slice(1), + ); + }, + { numRuns: 1000 }, + ); + // TODO: // - Test for badly structured data }); diff --git a/tests/RPC/utils.ts b/tests/RPC/utils.ts index e77b4e9a6..f8c4aed2f 100644 --- a/tests/RPC/utils.ts +++ b/tests/RPC/utils.ts @@ -4,7 +4,7 @@ import type { TransformerTransformCallback, ReadableWritablePair, } from 'stream/web'; -import type { POJO } from '@/types'; +import type { JSONValue, POJO } from '@/types'; import type { JsonRpcError, JsonRpcMessage, @@ -15,7 +15,6 @@ import type { JsonRpcResponse, JsonRpcRequest, } from '@/RPC/types'; -import type { JsonValue } from 'fast-check'; import { ReadableStream, WritableStream, TransformStream } from 'stream/web'; import { fc } from '@fast-check/jest'; import * as utils from '@/utils'; @@ -106,13 +105,13 @@ const jsonRpcStream = (messages: Array) => { const safeJsonValueArb = fc .jsonValue() - .map((value) => JSON.parse(JSON.stringify(value))); + .map((value) => JSON.parse(JSON.stringify(value)) as JSONValue); const idArb = fc.oneof(fc.string(), fc.integer(), fc.constant(null)); const jsonRpcRequestMessageArb = ( method: fc.Arbitrary = fc.string(), - params: fc.Arbitrary = safeJsonValueArb, + params: fc.Arbitrary = safeJsonValueArb, ) => fc .record( @@ -126,11 +125,11 @@ const jsonRpcRequestMessageArb = ( requiredKeys: ['jsonrpc', 'method', 'id'], }, ) - .noShrink() as fc.Arbitrary; + .noShrink() as fc.Arbitrary>; const jsonRpcRequestNotificationArb = ( method: fc.Arbitrary = fc.string(), - params: fc.Arbitrary = safeJsonValueArb, + params: fc.Arbitrary = safeJsonValueArb, ) => fc .record( @@ -147,7 +146,7 @@ const jsonRpcRequestNotificationArb = ( const jsonRpcRequestArb = ( method: fc.Arbitrary = fc.string(), - params: fc.Arbitrary = safeJsonValueArb, + params: fc.Arbitrary = safeJsonValueArb, ) => fc .oneof( @@ -157,7 +156,7 @@ const jsonRpcRequestArb = ( .noShrink() as fc.Arbitrary; const jsonRpcResponseResultArb = ( - result: fc.Arbitrary = safeJsonValueArb, + result: fc.Arbitrary = safeJsonValueArb, ) => fc .record({ @@ -191,7 +190,7 @@ const jsonRpcResponseErrorArb = (error?: Error) => .noShrink() as fc.Arbitrary; const jsonRpcResponseArb = ( - result: fc.Arbitrary = safeJsonValueArb, + result: fc.Arbitrary = safeJsonValueArb, ) => fc .oneof(jsonRpcResponseResultArb(result), jsonRpcResponseErrorArb()) @@ -199,8 +198,8 @@ const jsonRpcResponseArb = ( const jsonRpcMessageArb = ( method: fc.Arbitrary = fc.string(), - params: fc.Arbitrary = safeJsonValueArb, - result: fc.Arbitrary = safeJsonValueArb, + params: fc.Arbitrary = safeJsonValueArb, + result: fc.Arbitrary = safeJsonValueArb, ) => fc .oneof(jsonRpcRequestArb(method, params), jsonRpcResponseArb(result)) @@ -214,6 +213,8 @@ const jsonMessagesArb = fc .array(jsonRpcRequestMessageArb(), { minLength: 2 }) .noShrink(); +const rawDataArb = fc.array(fc.uint8Array({ minLength: 1 }), { minLength: 1 }); + function streamToArray(): [Promise>, WritableStream] { const outputArray: Array = []; const result = utils.promise>(); @@ -293,6 +294,7 @@ export { jsonRpcMessageArb, snippingPatternArb, jsonMessagesArb, + rawDataArb, streamToArray, TapTransformerStream, createTapPairs, diff --git a/tests/clientRPC/handlers/agentStatus.test.ts b/tests/clientRPC/handlers/agentStatus.test.ts index be0f28556..78b29a537 100644 --- a/tests/clientRPC/handlers/agentStatus.test.ts +++ b/tests/clientRPC/handlers/agentStatus.test.ts @@ -1,4 +1,5 @@ import type { Server } from 'https'; +import type { WebSocketServer } from 'ws'; import fs from 'fs'; import path from 'path'; import os from 'os'; @@ -30,6 +31,9 @@ describe('agentStatus', () => { let taskManager: TaskManager; let certManager: CertManager; let server: Server; + let wss: WebSocketServer; + const host = '127.0.0.1'; + let port: number; beforeEach(async () => { dataDir = await fs.promises.mkdtemp( @@ -61,10 +65,11 @@ describe('agentStatus', () => { cert: tlsConfig.certChainPem, key: tlsConfig.keyPrivatePem, }); - server.listen(8080, '127.0.0.1'); + port = await clientRPCUtils.listen(server, host); }); afterEach(async () => { - server.close(); + wss?.close(); + server?.close(); await certManager.stop(); await taskManager.stop(); await keyRing.stop(); @@ -74,7 +79,7 @@ describe('agentStatus', () => { recursive: true, }); }); - test('get status', async () => { + test('get status %s', async () => { // Setup const rpcServer = await RPCServer.createRPCServer({ container: { @@ -85,7 +90,7 @@ describe('agentStatus', () => { logger: logger.getChild('RPCServer'), }); rpcServer.registerUnaryHandler(agentStatusName, agentStatusHandler); - clientRPCUtils.createClientServer( + wss = clientRPCUtils.createClientServer( server, rpcServer, logger.getChild('server'), @@ -93,13 +98,13 @@ describe('agentStatus', () => { const rpcClient = await RPCClient.createRPCClient({ streamPairCreateCallback: async () => { return clientRPCUtils.startConnection( - 'wss://localhost:8080', + host, + port, logger.getChild('client'), ); }, logger: logger.getChild('RPCClient'), }); - // Doing the test const result = await agentStatusCaller({}, rpcClient); expect(result).toStrictEqual({ diff --git a/tests/clientRPC/handlers/agentUnlock.test.ts b/tests/clientRPC/handlers/agentUnlock.test.ts index e08f8e798..5e6cdcf45 100644 --- a/tests/clientRPC/handlers/agentUnlock.test.ts +++ b/tests/clientRPC/handlers/agentUnlock.test.ts @@ -1,4 +1,5 @@ import type { Server } from 'https'; +import type { WebSocketServer } from 'ws'; import fs from 'fs'; import path from 'path'; import os from 'os'; @@ -22,7 +23,7 @@ import * as clientRPCUtils from '@/clientRPC/utils'; import * as testsUtils from '../../utils'; describe('agentUnlock', () => { - const logger = new Logger('agentUnlock test', LogLevel.INFO, [ + const logger = new Logger('agentUnlock test', LogLevel.WARN, [ new StreamHandler(), ]); const password = 'helloworld'; @@ -34,6 +35,8 @@ describe('agentUnlock', () => { let session: Session; let sessionManager: SessionManager; let server: Server; + let wss: WebSocketServer; + let port: number; beforeEach(async () => { dataDir = await fs.promises.mkdtemp( @@ -75,9 +78,10 @@ describe('agentUnlock', () => { cert: tlsConfig.certChainPem, key: tlsConfig.keyPrivatePem, }); - server.listen(8080, '127.0.0.1'); + port = await clientRPCUtils.listen(server, '127.0.0.1'); }); afterEach(async () => { + wss?.close(); server.close(); await certManager.stop(); await taskManager.stop(); @@ -100,7 +104,7 @@ describe('agentUnlock', () => { rpcServer.registerMiddleware( abcUtils.authenticationMiddlewareServer(sessionManager, keyRing), ); - clientRPCUtils.createClientServer( + wss = clientRPCUtils.createClientServer( server, rpcServer, logger.getChild('server'), @@ -108,7 +112,8 @@ describe('agentUnlock', () => { const rpcClient = await RPCClient.createRPCClient({ streamPairCreateCallback: async () => { return clientRPCUtils.startConnection( - 'wss://localhost:8080', + '127.0.0.1', + port, logger.getChild('client'), ); }, diff --git a/tests/clientRPC/websocket.test.ts b/tests/clientRPC/websocket.test.ts index 345fe56cf..79d939953 100644 --- a/tests/clientRPC/websocket.test.ts +++ b/tests/clientRPC/websocket.test.ts @@ -1,5 +1,6 @@ import type { TLSConfig } from '@/network/types'; import type { Server } from 'https'; +import type { WebSocketServer } from 'ws'; import fs from 'fs'; import path from 'path'; import os from 'os'; @@ -21,6 +22,9 @@ describe('websocket', () => { let keyRing: KeyRing; let tlsConfig: TLSConfig; let server: Server; + let wss: WebSocketServer; + const host = '127.0.0.1'; + let port: number; beforeEach(async () => { dataDir = await fs.promises.mkdtemp( @@ -37,9 +41,10 @@ describe('websocket', () => { cert: tlsConfig.certChainPem, key: tlsConfig.keyPrivatePem, }); - server.listen(8080, '127.0.0.1'); + port = await clientRPCUtils.listen(server, host); }); afterEach(async () => { + wss?.close(); server.close(); await keyRing.stop(); await fs.promises.rm(dataDir, { force: true, recursive: true }); @@ -61,18 +66,19 @@ describe('websocket', () => { return { hello: 'not world' }; }); - clientRPCUtils.createClientServer( + wss = clientRPCUtils.createClientServer( server, rpcServer, logger.getChild('client'), ); // Setting up client - const rpcClient = new RPCClient({ + const rpcClient = await RPCClient.createRPCClient({ logger: logger.getChild('RPCClient'), streamPairCreateCallback: async () => { return clientRPCUtils.startConnection( - 'wss://localhost:8080', + host, + port, logger.getChild('Connection'), ); }, From ceddf68653d87d124e34a84ab692e155f9bc32fc Mon Sep 17 00:00:00 2001 From: Brian Botha Date: Mon, 6 Feb 2023 19:16:31 +1100 Subject: [PATCH 29/44] feat: client side manifest and typed callers - Related #501 [ci skip] --- src/RPC/RPCClient.ts | 97 +++- src/RPC/types.ts | 160 ++++++- src/RPC/utils.ts | 10 + tests/RPC/RPC.test.ts | 5 + tests/RPC/RPCClient.test.ts | 459 ++++++++++++++++++- tests/clientRPC/handlers/agentStatus.test.ts | 1 + tests/clientRPC/websocket.test.ts | 1 + 7 files changed, 719 insertions(+), 14 deletions(-) diff --git a/src/RPC/RPCClient.ts b/src/RPC/RPCClient.ts index fae7fdbcf..06513e0e3 100644 --- a/src/RPC/RPCClient.ts +++ b/src/RPC/RPCClient.ts @@ -1,28 +1,44 @@ -import type { JsonRpcRequestMessage, StreamPairCreateCallback } from './types'; +import type { + HandlerType, + JsonRpcRequestMessage, + Manifest, + MapWithHandlers, + StreamPairCreateCallback, +} from './types'; import type { JSONValue } from 'types'; -import type { ReadableWritablePair } from 'stream/web'; +import type { + ReadableWritablePair, + ReadableStream, + WritableStream, +} from 'stream/web'; import type { JsonRpcRequest, JsonRpcResponse, MiddlewareFactory, + MapHandlers, } from './types'; import { CreateDestroy, ready } from '@matrixai/async-init/dist/CreateDestroy'; import Logger from '@matrixai/logger'; import * as rpcErrors from './errors'; import * as rpcUtils from './utils'; +import { getHandlerTypes } from './utils'; -interface RPCClient extends CreateDestroy {} +// eslint-disable-next-line +interface RPCClient extends CreateDestroy {} @CreateDestroy() -class RPCClient { - static async createRPCClient({ +class RPCClient { + static async createRPCClient({ + manifest, streamPairCreateCallback, logger = new Logger(this.name), }: { + manifest: M; streamPairCreateCallback: StreamPairCreateCallback; logger: Logger; }) { logger.info(`Creating ${this.name}`); const rpcClient = new this({ + manifest, streamPairCreateCallback, logger, }); @@ -32,16 +48,64 @@ class RPCClient { protected logger: Logger; protected streamPairCreateCallback: StreamPairCreateCallback; + protected callerTypes: Record; + // Method proxies + protected methodsProxy = new Proxy( + {}, + { + get: (_, method) => { + if (typeof method === 'symbol') throw Error('invalid symbol'); + switch (this.callerTypes[method]) { + case 'DUPLEX': + return () => this.duplexStreamCaller(method); + case 'SERVER': + return (params) => this.serverStreamCaller(method, params); + case 'CLIENT': + return () => this.clientStreamCaller(method); + case 'UNARY': + return (params) => this.unaryCaller(method, params); + case 'RAW': + return (params) => this.rawStreamCaller(method, params); + default: + return; + } + }, + }, + ); + protected withMethodsProxy = new Proxy( + {}, + { + get: (_, method) => { + if (typeof method === 'symbol') throw Error('invalid symbol'); + switch (this.callerTypes[method]) { + case 'DUPLEX': + return (f) => this.withDuplexCaller(method, f); + case 'SERVER': + return (params, f) => this.withServerCaller(method, params, f); + case 'CLIENT': + return (f) => this.withClientCaller(method, f); + case 'RAW': + return (params, f) => this.withRawStreamCaller(method, params, f); + case 'UNARY': + default: + return; + } + }, + }, + ); public constructor({ + manifest, streamPairCreateCallback, logger, }: { + manifest: M; streamPairCreateCallback: StreamPairCreateCallback; logger: Logger; }) { - this.logger = logger; + this.callerTypes = getHandlerTypes(manifest); this.streamPairCreateCallback = streamPairCreateCallback; + this.logger = logger; } public async destroy(): Promise { @@ -49,6 +113,16 @@ class RPCClient { this.logger.info(`Destroyed ${this.constructor.name}`); } + @ready(new rpcErrors.ErrorRpcDestroyed()) + public get methods(): MapHandlers { + return this.methodsProxy as MapHandlers; + } + + @ready(new rpcErrors.ErrorRpcDestroyed()) + public get withMethods(): MapWithHandlers { + return this.withMethodsProxy as MapWithHandlers; + } + @ready(new rpcErrors.ErrorRpcDestroyed()) public async rawStreamCaller( method: string, @@ -109,7 +183,7 @@ class RPCClient { public async serverStreamCaller( method: string, parameters: I, - ) { + ): Promise> { const callerInterface = await this.duplexStreamCaller(method); const writer = callerInterface.writable.getWriter(); await writer.write(parameters); @@ -121,7 +195,10 @@ class RPCClient { @ready(new rpcErrors.ErrorRpcDestroyed()) public async clientStreamCaller( method: string, - ) { + ): Promise<{ + output: Promise; + writable: WritableStream; + }> { const callerInterface = await this.duplexStreamCaller(method); const reader = callerInterface.readable.getReader(); const output = reader.read().then(({ value, done }) => { @@ -159,7 +236,7 @@ class RPCClient { method: string, params: JSONValue, f: (output: AsyncGenerator) => AsyncGenerator, - ) { + ): Promise { const callerInterface = await this.rawStreamCaller(method, params); const outputGenerator = async function* () { for await (const value of callerInterface.readable) { @@ -196,7 +273,7 @@ class RPCClient { method: string, parameters: I, f: (output: AsyncGenerator) => Promise, - ) { + ): Promise { const callerInterface = await this.serverStreamCaller( method, parameters, diff --git a/src/RPC/types.ts b/src/RPC/types.ts index 06a7f15fd..1995030dd 100644 --- a/src/RPC/types.ts +++ b/src/RPC/types.ts @@ -1,7 +1,11 @@ import type { JSONValue, POJO } from '../types'; import type { ConnectionInfo } from '../network/types'; import type { ContextCancellable } from '../contexts/types'; -import type { ReadableStream, ReadableWritablePair } from 'stream/web'; +import type { + ReadableStream, + ReadableWritablePair, + WritableStream, +} from 'stream/web'; /** * This is the JSON RPC request object. this is the generic message type used for the RPC. @@ -136,6 +140,155 @@ type MiddlewareFactory = () => { reverse: ReadableWritablePair; }; +type DuplexStreamCaller< + I extends JSONValue, + O extends JSONValue, +> = () => Promise>; + +type ServerStreamCaller = ( + parameters: I, +) => Promise>; + +type ClientStreamCaller< + I extends JSONValue, + O extends JSONValue, +> = () => Promise<{ + output: Promise; + writable: WritableStream; +}>; + +type UnaryCaller = ( + parameters: I, +) => Promise; + +type RawStreamCaller = ( + params: JSONValue, +) => Promise>; + +type ConvertDuplexStreamHandler = T extends DuplexStreamHandler< + infer I, + infer O +> + ? DuplexStreamCaller + : never; + +type ConvertServerStreamHandler = T extends ServerStreamHandler< + infer I, + infer O +> + ? ServerStreamCaller + : never; + +type ConvertClientStreamHandler = T extends ClientStreamHandler< + infer I, + infer O +> + ? ClientStreamCaller + : never; + +type ConvertUnaryCaller = T extends UnaryHandler + ? UnaryCaller + : never; + +type ConvertHandler = T extends DuplexStreamHandler + ? ConvertDuplexStreamHandler + : T extends ServerStreamHandler + ? ConvertServerStreamHandler + : T extends ClientStreamHandler + ? ConvertClientStreamHandler + : T extends UnaryHandler + ? ConvertUnaryCaller + : T extends RawDuplexStreamHandler + ? RawStreamCaller + : never; + +type WithDuplexStreamCaller = ( + f: (output: AsyncGenerator) => AsyncGenerator, +) => Promise; + +type WithServerStreamCaller = ( + parameters: I, + f: (output: AsyncGenerator) => Promise, +) => Promise; + +type WithClientStreamCaller = ( + f: () => AsyncGenerator, +) => Promise; + +type WithRawStreamCaller = ( + params: JSONValue, + f: (output: AsyncGenerator) => AsyncGenerator, +) => Promise; + +type ConvertWithDuplexStreamHandler = T extends DuplexStreamHandler< + infer I, + infer O +> + ? WithDuplexStreamCaller + : never; + +type ConvertWithServerStreamHandler = T extends ServerStreamHandler< + infer I, + infer O +> + ? WithServerStreamCaller + : never; + +type ConvertWithClientStreamHandler = T extends ClientStreamHandler< + infer I, + infer O +> + ? WithClientStreamCaller + : never; + +type ConvertWithHandler = T extends DuplexStreamHandler + ? ConvertWithDuplexStreamHandler + : T extends ServerStreamHandler + ? ConvertWithServerStreamHandler + : T extends ClientStreamHandler + ? ConvertWithClientStreamHandler + : T extends RawDuplexStreamHandler + ? WithRawStreamCaller + : never; + +type HandlerType = 'DUPLEX' | 'SERVER' | 'CLIENT' | 'UNARY' | 'RAW'; + +type ManifestItem = + | { + type: 'DUPLEX'; + handler: DuplexStreamHandler; + } + | { + type: 'SERVER'; + handler: ServerStreamHandler; + } + | { + type: 'CLIENT'; + handler: ClientStreamHandler; + } + | { + type: 'UNARY'; + handler: UnaryHandler; + } + | { + type: 'RAW'; + handler: RawDuplexStreamHandler; + }; + +type Manifest = Record>; + +type ExtractHandler = T extends ManifestItem + ? T['handler'] + : never; + +type MapHandlers = { + [P in keyof T]: ConvertHandler>; +}; + +type MapWithHandlers = { + [P in keyof T]: ConvertWithHandler>; +}; + export type { JsonRpcRequestMessage, JsonRpcRequestNotification, @@ -152,4 +305,9 @@ export type { UnaryHandler, StreamPairCreateCallback, MiddlewareFactory, + HandlerType, + ManifestItem, + Manifest, + MapHandlers, + MapWithHandlers, }; diff --git a/src/RPC/utils.ts b/src/RPC/utils.ts index bacd30224..9a05e9b28 100644 --- a/src/RPC/utils.ts +++ b/src/RPC/utils.ts @@ -16,6 +16,7 @@ import type { } from 'RPC/types'; import type { JSONValue } from '../types'; import type { JsonValue } from 'fast-check'; +import type { HandlerType, Manifest } from 'RPC/types'; import { TransformStream } from 'stream/web'; import { AbstractError } from '@matrixai/errors'; import * as rpcErrors from './errors'; @@ -628,6 +629,14 @@ function extractFirstMessageTransform( return { headTransformStream, firstMessageProm: messageProm.p }; } +function getHandlerTypes(manifest: Manifest): Record { + const out: Record = {}; + for (const [k, v] of Object.entries(manifest)) { + out[k] = v.type; + } + return out; +} + export { JsonToJsonMessageStream, JsonMessageToJsonStream, @@ -647,4 +656,5 @@ export { controllerTransformationFactory, QueueMergingTransformStream, extractFirstMessageTransform, + getHandlerTypes, }; diff --git a/tests/RPC/RPC.test.ts b/tests/RPC/RPC.test.ts index 33572944a..03eb95d31 100644 --- a/tests/RPC/RPC.test.ts +++ b/tests/RPC/RPC.test.ts @@ -47,6 +47,7 @@ describe('RPC', () => { rpcServer.handleStream(serverPair, {} as ConnectionInfo); const rpcClient = await RPCClient.createRPCClient({ + manifest: {}, streamPairCreateCallback: async () => clientPair, logger, }); @@ -96,6 +97,7 @@ describe('RPC', () => { rpcServer.handleStream(serverPair, {} as ConnectionInfo); const rpcClient = await RPCClient.createRPCClient({ + manifest: {}, streamPairCreateCallback: async () => clientPair, logger, }); @@ -138,6 +140,7 @@ describe('RPC', () => { rpcServer.handleStream(serverPair, {} as ConnectionInfo); const rpcClient = await RPCClient.createRPCClient({ + manifest: {}, streamPairCreateCallback: async () => clientPair, logger, }); @@ -181,6 +184,7 @@ describe('RPC', () => { rpcServer.handleStream(serverPair, {} as ConnectionInfo); const rpcClient = await RPCClient.createRPCClient({ + manifest: {}, streamPairCreateCallback: async () => clientPair, logger, }); @@ -219,6 +223,7 @@ describe('RPC', () => { rpcServer.handleStream(serverPair, {} as ConnectionInfo); const rpcClient = await RPCClient.createRPCClient({ + manifest: {}, streamPairCreateCallback: async () => clientPair, logger, }); diff --git a/tests/RPC/RPCClient.test.ts b/tests/RPC/RPCClient.test.ts index 18984e06f..863bb99fa 100644 --- a/tests/RPC/RPCClient.test.ts +++ b/tests/RPC/RPCClient.test.ts @@ -4,6 +4,7 @@ import type { JsonRpcRequest, JsonRpcRequestMessage, JsonRpcResponse, + ManifestItem, } from '@/RPC/types'; import { TransformStream, ReadableStream } from 'stream/web'; import Logger, { LogLevel, StreamHandler } from '@matrixai/logger'; @@ -49,6 +50,7 @@ describe(`${RPCClient.name}`, () => { writable: inputWritableStream, }; const rpcClient = await RPCClient.createRPCClient({ + manifest: {}, streamPairCreateCallback: async () => streamPair, logger, }); @@ -85,6 +87,7 @@ describe(`${RPCClient.name}`, () => { writable: outputStream, }; const rpcClient = await RPCClient.createRPCClient({ + manifest: {}, streamPairCreateCallback: async () => streamPair, logger, }); @@ -120,7 +123,7 @@ describe(`${RPCClient.name}`, () => { }); testProp( 'generic server stream caller', - [specificMessageArb, fc.jsonValue()], + [specificMessageArb, rpcTestUtils.safeJsonValueArb], async (messages, params) => { const inputStream = rpcTestUtils.jsonRpcStream(messages); const [outputResult, outputStream] = rpcTestUtils.streamToArray(); @@ -129,6 +132,7 @@ describe(`${RPCClient.name}`, () => { writable: outputStream, }; const rpcClient = await RPCClient.createRPCClient({ + manifest: {}, streamPairCreateCallback: async () => streamPair, logger, }); @@ -155,7 +159,10 @@ describe(`${RPCClient.name}`, () => { ); testProp( 'generic client stream caller', - [rpcTestUtils.jsonRpcResponseResultArb(), fc.array(fc.jsonValue())], + [ + rpcTestUtils.jsonRpcResponseResultArb(), + fc.array(rpcTestUtils.safeJsonValueArb), + ], async (message, params) => { const inputStream = rpcTestUtils.jsonRpcStream([message]); const [outputResult, outputStream] = @@ -165,6 +172,7 @@ describe(`${RPCClient.name}`, () => { writable: outputStream, }; const rpcClient = await RPCClient.createRPCClient({ + manifest: {}, streamPairCreateCallback: async () => streamPair, logger, }); @@ -194,7 +202,7 @@ describe(`${RPCClient.name}`, () => { ); testProp( 'generic unary caller', - [rpcTestUtils.jsonRpcResponseResultArb(), fc.jsonValue()], + [rpcTestUtils.jsonRpcResponseResultArb(), rpcTestUtils.safeJsonValueArb], async (message, params) => { const inputStream = rpcTestUtils.jsonRpcStream([message]); const [outputResult, outputStream] = rpcTestUtils.streamToArray(); @@ -203,6 +211,7 @@ describe(`${RPCClient.name}`, () => { writable: outputStream, }; const rpcClient = await RPCClient.createRPCClient({ + manifest: {}, streamPairCreateCallback: async () => streamPair, logger, }); @@ -240,6 +249,7 @@ describe(`${RPCClient.name}`, () => { writable: outputStream, }; const rpcClient = await RPCClient.createRPCClient({ + manifest: {}, streamPairCreateCallback: async () => streamPair, logger, }); @@ -282,6 +292,7 @@ describe(`${RPCClient.name}`, () => { writable: inputWritableStream, }; const rpcClient = await RPCClient.createRPCClient({ + manifest: {}, streamPairCreateCallback: async () => streamPair, logger, }); @@ -323,6 +334,7 @@ describe(`${RPCClient.name}`, () => { writable: outputStream, }; const rpcClient = await RPCClient.createRPCClient({ + manifest: {}, streamPairCreateCallback: async () => streamPair, logger, }); @@ -355,6 +367,7 @@ describe(`${RPCClient.name}`, () => { writable: outputStream, }; const rpcClient = await RPCClient.createRPCClient({ + manifest: {}, streamPairCreateCallback: async () => streamPair, logger, }); @@ -390,6 +403,7 @@ describe(`${RPCClient.name}`, () => { writable: outputStream, }; const rpcClient = await RPCClient.createRPCClient({ + manifest: {}, streamPairCreateCallback: async () => streamPair, logger, }); @@ -428,6 +442,7 @@ describe(`${RPCClient.name}`, () => { writable: outputStream, }; const rpcClient = await RPCClient.createRPCClient({ + manifest: {}, streamPairCreateCallback: async () => streamPair, logger, }); @@ -494,6 +509,7 @@ describe(`${RPCClient.name}`, () => { writable: outputStream, }; const rpcClient = await RPCClient.createRPCClient({ + manifest: {}, streamPairCreateCallback: async () => streamPair, logger, }); @@ -534,4 +550,441 @@ describe(`${RPCClient.name}`, () => { await rpcClient.destroy(); }, ); + testProp( + 'manifest duplex call', + [ + fc.array(rpcTestUtils.jsonRpcResponseResultArb(fc.string()), { + minLength: 5, + }), + ], + async (messages) => { + const inputStream = rpcTestUtils.jsonRpcStream(messages); + const [outputResult, outputStream] = rpcTestUtils.streamToArray(); + const streamPair: ReadableWritablePair = { + readable: inputStream, + writable: outputStream, + }; + const duplex: ManifestItem = { + type: 'DUPLEX', + handler: async function* (input) { + yield* input; + }, + }; + const rpcClient = await RPCClient.createRPCClient({ + manifest: { + duplex, + }, + streamPairCreateCallback: async () => streamPair, + logger, + }); + const callerInterface = await rpcClient.methods.duplex(); + const reader = callerInterface.readable.getReader(); + const writer = callerInterface.writable.getWriter(); + while (true) { + const { value, done } = await reader.read(); + if (done) { + // We have to end the writer otherwise the stream never closes + await writer.close(); + break; + } + await writer.write(value); + } + const expectedMessages: Array = messages.map( + (v) => { + const request: JsonRpcRequestMessage = { + jsonrpc: '2.0', + method: 'duplex', + id: null, + ...(v.result === undefined ? {} : { params: v.result }), + }; + return request; + }, + ); + const outputMessages = (await outputResult).map((v) => + JSON.parse(v.toString()), + ); + expect(outputMessages).toStrictEqual(expectedMessages); + + await rpcClient.destroy(); + }, + ); + testProp( + 'manifest server call', + [specificMessageArb, fc.string()], + async (messages, params) => { + const inputStream = rpcTestUtils.jsonRpcStream(messages); + const [outputResult, outputStream] = rpcTestUtils.streamToArray(); + const streamPair: ReadableWritablePair = { + readable: inputStream, + writable: outputStream, + }; + const server: ManifestItem = { + type: 'SERVER', + handler: async function* (input) { + yield* input; + }, + }; + const rpcClient = await RPCClient.createRPCClient({ + manifest: { + server, + }, + streamPairCreateCallback: async () => streamPair, + logger, + }); + const callerInterface = await rpcClient.methods.server(params); + const values: Array = []; + for await (const value of callerInterface) { + values.push(value); + } + const expectedValues = messages.map((v) => v.result); + expect(values).toStrictEqual(expectedValues); + expect((await outputResult)[0]?.toString()).toStrictEqual( + JSON.stringify({ + method: 'server', + jsonrpc: '2.0', + id: null, + params, + }), + ); + await rpcClient.destroy(); + }, + ); + testProp( + 'manifest client call', + [ + rpcTestUtils.jsonRpcResponseResultArb(fc.string()), + fc.array(fc.string(), { minLength: 5 }), + ], + async (message, params) => { + const inputStream = rpcTestUtils.jsonRpcStream([message]); + const [outputResult, outputStream] = + rpcTestUtils.streamToArray(); + const streamPair: ReadableWritablePair = { + readable: inputStream, + writable: outputStream, + }; + const client: ManifestItem = { + type: 'CLIENT', + handler: async (_) => { + return 'hello'; + }, + }; + const rpcClient = await RPCClient.createRPCClient({ + manifest: { + client, + }, + streamPairCreateCallback: async () => streamPair, + logger, + }); + const callerInterface = await rpcClient.methods.client(); + const writer = callerInterface.writable.getWriter(); + for (const param of params) { + await writer.write(param); + } + await writer.close(); + expect(await callerInterface.output).toStrictEqual(message.result); + const expectedOutput = params.map((v) => + JSON.stringify({ + method: 'client', + jsonrpc: '2.0', + id: null, + params: v, + }), + ); + expect((await outputResult).map((v) => v.toString())).toStrictEqual( + expectedOutput, + ); + await rpcClient.destroy(); + }, + ); + testProp( + 'manifest unary call', + [rpcTestUtils.jsonRpcResponseResultArb().noShrink(), fc.string()], + async (message, params) => { + const inputStream = rpcTestUtils.jsonRpcStream([message]); + const [outputResult, outputStream] = rpcTestUtils.streamToArray(); + const streamPair: ReadableWritablePair = { + readable: inputStream, + writable: outputStream, + }; + const unary: ManifestItem = { + type: 'UNARY', + handler: async (input) => input, + }; + const rpcClient = await RPCClient.createRPCClient({ + manifest: { + unary, + }, + streamPairCreateCallback: async () => streamPair, + logger, + }); + const result = await rpcClient.methods.unary(params); + expect(result).toStrictEqual(message.result); + expect((await outputResult)[0]?.toString()).toStrictEqual( + JSON.stringify({ + method: 'unary', + jsonrpc: '2.0', + id: null, + params: params, + }), + ); + await rpcClient.destroy(); + }, + ); + testProp( + 'manifest raw duplex caller', + [ + rpcTestUtils.safeJsonValueArb, + rpcTestUtils.rawDataArb, + rpcTestUtils.rawDataArb, + ], + async (headerParams, inputData, outputData) => { + const [inputResult, inputWritableStream] = + rpcTestUtils.streamToArray(); + const [outputResult, outputWritableStream] = + rpcTestUtils.streamToArray(); + const streamPair: ReadableWritablePair = { + readable: new ReadableStream({ + start: (controller) => { + for (const datum of outputData) { + controller.enqueue(datum); + } + controller.close(); + }, + }), + writable: inputWritableStream, + }; + const raw: ManifestItem = { + type: 'RAW', + handler: ([input]) => input, + }; + const rpcClient = await RPCClient.createRPCClient({ + manifest: { + raw, + }, + streamPairCreateCallback: async () => streamPair, + logger, + }); + const callerInterface = await rpcClient.methods.raw(headerParams); + await callerInterface.readable.pipeTo(outputWritableStream); + const writer = callerInterface.writable.getWriter(); + for (const inputDatum of inputData) { + await writer.write(inputDatum); + } + await writer.close(); + + const expectedHeader: JsonRpcRequest = { + jsonrpc: '2.0', + method: 'raw', + params: headerParams, + id: null, + }; + expect(await inputResult).toStrictEqual([ + Buffer.from(JSON.stringify(expectedHeader)), + ...inputData, + ]); + expect(await outputResult).toStrictEqual(outputData); + }, + ); + testProp( + 'manifest withDuplex caller', + [ + fc.array(rpcTestUtils.jsonRpcResponseResultArb(fc.string()), { + minLength: 1, + }), + ], + async (messages) => { + const inputStream = rpcTestUtils.jsonRpcStream(messages); + const [outputResult, outputStream] = + rpcTestUtils.streamToArray(); + const streamPair: ReadableWritablePair = { + readable: inputStream, + writable: outputStream, + }; + const duplex: ManifestItem = { + type: 'DUPLEX', + handler: async function* (input) { + yield* input; + }, + }; + const rpcClient = await RPCClient.createRPCClient({ + manifest: { + duplex, + }, + streamPairCreateCallback: async () => streamPair, + logger, + }); + let count = 0; + await rpcClient.withMethods.duplex(async function* (output) { + for await (const value of output) { + count += 1; + yield value; + } + }); + const result = await outputResult; + // We're just checking that it's consuming the messages as expected + expect(result.length).toEqual(messages.length); + expect(count).toEqual(messages.length); + await rpcClient.destroy(); + }, + ); + testProp( + 'manifest withServer caller', + [ + fc.array(rpcTestUtils.jsonRpcResponseResultArb(), { minLength: 1 }), + fc.string(), + ], + async (messages, params) => { + const inputStream = rpcTestUtils.jsonRpcStream(messages); + const [outputResult, outputStream] = + rpcTestUtils.streamToArray(); + const streamPair: ReadableWritablePair = { + readable: inputStream, + writable: outputStream, + }; + const server: ManifestItem = { + type: 'SERVER', + handler: async function* (input) { + yield input; + }, + }; + const rpcClient = await RPCClient.createRPCClient({ + manifest: { + server, + }, + streamPairCreateCallback: async () => streamPair, + logger, + }); + let count = 0; + await rpcClient.withMethods.server(params, async (output) => { + for await (const _ of output) count += 1; + }); + const result = await outputResult; + expect(count).toEqual(messages.length); + expect(result.toString()).toStrictEqual( + JSON.stringify({ + method: 'server', + jsonrpc: '2.0', + id: null, + params: params, + }), + ); + await rpcClient.destroy(); + }, + ); + testProp( + 'manifest withClient caller', + [ + rpcTestUtils.jsonRpcResponseResultArb(), + fc.array(fc.string(), { minLength: 2 }).noShrink(), + ], + async (message, inputMessages) => { + const inputStream = rpcTestUtils.jsonRpcStream([message]); + const [outputResult, outputStream] = + rpcTestUtils.streamToArray(); + const streamPair: ReadableWritablePair = { + readable: inputStream, + writable: outputStream, + }; + const client: ManifestItem = { + type: 'CLIENT', + handler: async (_) => { + return 'someValue'; + }, + }; + const rpcClient = await RPCClient.createRPCClient({ + manifest: { + client, + }, + streamPairCreateCallback: async () => streamPair, + logger, + }); + const result = await rpcClient.withMethods.client(async function* () { + for (const inputMessage of inputMessages) { + yield inputMessage; + } + }); + const expectedResult = inputMessages.map((v) => { + return JSON.stringify({ + method: 'client', + jsonrpc: '2.0', + id: null, + params: v, + }); + }); + expect((await outputResult).map((v) => v.toString())).toStrictEqual( + expectedResult, + ); + expect(result).toStrictEqual(message.result); + await rpcClient.destroy(); + }, + ); + testProp( + 'manifest withRaw caller', + [ + rpcTestUtils.safeJsonValueArb, + rpcTestUtils.rawDataArb, + rpcTestUtils.rawDataArb, + ], + async (headerParams, inputData, outputData) => { + const [inputResult, inputWritableStream] = + rpcTestUtils.streamToArray(); + const streamPair: ReadableWritablePair = { + readable: new ReadableStream({ + start: (controller) => { + for (const datum of outputData) { + controller.enqueue(datum); + } + controller.close(); + }, + }), + writable: inputWritableStream, + }; + const raw: ManifestItem = { + type: 'RAW', + handler: ([input]) => input, + }; + const rpcClient = await RPCClient.createRPCClient({ + manifest: { + raw, + }, + streamPairCreateCallback: async () => streamPair, + logger, + }); + const outputResult: Array = []; + await rpcClient.withMethods.raw(headerParams, async function* (output) { + for await (const outputValue of output) { + outputResult.push(outputValue); + } + for (const inputDatum of inputData) { + yield inputDatum; + } + }); + const expectedHeader: JsonRpcRequest = { + jsonrpc: '2.0', + method: 'raw', + params: headerParams, + id: null, + }; + expect(await inputResult).toStrictEqual([ + Buffer.from(JSON.stringify(expectedHeader)), + ...inputData, + ]); + expect(outputResult).toStrictEqual(outputData); + }, + ); + test('manifest without handler errors', async () => { + const rpcClient = await RPCClient.createRPCClient({ + manifest: {}, + streamPairCreateCallback: async () => { + return {} as ReadableWritablePair; + }, + logger, + }); + // @ts-ignore: ignoring type safety here + expect(() => rpcClient.methods.someMethod()).toThrow(); + // @ts-ignore: ignoring type safety here + expect(() => rpcClient.withMethods.someMethod()).toThrow(); + await rpcClient.destroy(); + }); }); diff --git a/tests/clientRPC/handlers/agentStatus.test.ts b/tests/clientRPC/handlers/agentStatus.test.ts index 78b29a537..68fdc155e 100644 --- a/tests/clientRPC/handlers/agentStatus.test.ts +++ b/tests/clientRPC/handlers/agentStatus.test.ts @@ -96,6 +96,7 @@ describe('agentStatus', () => { logger.getChild('server'), ); const rpcClient = await RPCClient.createRPCClient({ + manifest: {}, streamPairCreateCallback: async () => { return clientRPCUtils.startConnection( host, diff --git a/tests/clientRPC/websocket.test.ts b/tests/clientRPC/websocket.test.ts index 79d939953..818927c4a 100644 --- a/tests/clientRPC/websocket.test.ts +++ b/tests/clientRPC/websocket.test.ts @@ -74,6 +74,7 @@ describe('websocket', () => { // Setting up client const rpcClient = await RPCClient.createRPCClient({ + manifest: {}, logger: logger.getChild('RPCClient'), streamPairCreateCallback: async () => { return clientRPCUtils.startConnection( From c8ebd74d257fedc456e8ba1e4964fe8dad245040 Mon Sep 17 00:00:00 2001 From: Brian Botha Date: Tue, 7 Feb 2023 18:32:45 +1100 Subject: [PATCH 30/44] feat: static registering handlers using manifest - Related #500 - Related #501 [ci skip] --- src/RPC/RPCServer.ts | 61 ++- src/clientRPC/handlers/agentStatus.ts | 55 ++- src/clientRPC/handlers/agentUnlock.ts | 53 +-- tests/RPC/RPC.test.ts | 160 ++++---- tests/RPC/RPCServer.test.ts | 379 +++++++++++-------- tests/clientRPC/handlers/agentStatus.test.ts | 18 +- tests/clientRPC/handlers/agentUnlock.test.ts | 25 +- tests/clientRPC/websocket.test.ts | 35 +- 8 files changed, 432 insertions(+), 354 deletions(-) diff --git a/src/RPC/RPCServer.ts b/src/RPC/RPCServer.ts index c11a45c3a..d547f66f4 100644 --- a/src/RPC/RPCServer.ts +++ b/src/RPC/RPCServer.ts @@ -6,6 +6,7 @@ import type { JsonRpcResponse, JsonRpcResponseError, JsonRpcResponseResult, + Manifest, RawDuplexStreamHandler, ServerStreamHandler, UnaryHandler, @@ -21,20 +22,24 @@ import Logger from '@matrixai/logger'; import { PromiseCancellable } from '@matrixai/async-cancellable'; import * as rpcUtils from './utils'; import * as rpcErrors from './errors'; +import { never } from '../utils/utils'; import { sysexits } from '../errors'; interface RPCServer extends CreateDestroy {} @CreateDestroy() class RPCServer { static async createRPCServer({ + manifest, container, logger = new Logger(this.name), }: { + manifest: Manifest; container: POJO; logger?: Logger; }): Promise { logger.info(`Creating ${this.name}`); const rpcServer = new this({ + manifest, container, logger, }); @@ -50,12 +55,35 @@ class RPCServer { protected events: EventTarget = new EventTarget(); public constructor({ + manifest, container, logger, }: { + manifest: Manifest; container: POJO; logger: Logger; }) { + for (const [key, manifestItem] of Object.entries(manifest)) { + switch (manifestItem.type) { + case 'RAW': + this.registerRawStreamHandler(key, manifestItem.handler); + continue; + case 'DUPLEX': + this.registerDuplexStreamHandler(key, manifestItem.handler); + continue; + case 'SERVER': + this.registerServerStreamHandler(key, manifestItem.handler); + continue; + case 'CLIENT': + this.registerClientStreamHandler(key, manifestItem.handler); + continue; + case 'UNARY': + this.registerUnaryHandler(key, manifestItem.handler); + continue; + default: + never(); + } + } this.container = container; this.logger = logger; } @@ -72,19 +100,17 @@ class RPCServer { this.logger.info(`Destroyed ${this.constructor.name}`); } - @ready(new rpcErrors.ErrorRpcDestroyed()) - public registerRawStreamHandler( + protected registerRawStreamHandler( method: string, handler: RawDuplexStreamHandler, ) { this.handlerMap.set(method, handler); } - @ready(new rpcErrors.ErrorRpcDestroyed()) - public registerDuplexStreamHandler( - method: string, - handler: DuplexStreamHandler, - ) { + protected registerDuplexStreamHandler< + I extends JSONValue, + O extends JSONValue, + >(method: string, handler: DuplexStreamHandler) { // This needs to handle all the message parsing and conversion from // generators to the raw streams. @@ -186,8 +212,7 @@ class RPCServer { this.registerRawStreamHandler(method, rawSteamHandler); } - @ready(new rpcErrors.ErrorRpcDestroyed()) - public registerUnaryHandler( + protected registerUnaryHandler( method: string, handler: UnaryHandler, ) { @@ -205,11 +230,10 @@ class RPCServer { this.registerDuplexStreamHandler(method, wrapperDuplex); } - @ready(new rpcErrors.ErrorRpcDestroyed()) - public registerServerStreamHandler( - method: string, - handler: ServerStreamHandler, - ) { + protected registerServerStreamHandler< + I extends JSONValue, + O extends JSONValue, + >(method: string, handler: ServerStreamHandler) { const wrapperDuplex: DuplexStreamHandler = async function* ( input, container, @@ -224,11 +248,10 @@ class RPCServer { this.registerDuplexStreamHandler(method, wrapperDuplex); } - @ready(new rpcErrors.ErrorRpcDestroyed()) - public registerClientStreamHandler( - method: string, - handler: ClientStreamHandler, - ) { + protected registerClientStreamHandler< + I extends JSONValue, + O extends JSONValue, + >(method: string, handler: ClientStreamHandler) { const wrapperDuplex: DuplexStreamHandler = async function* ( input, container, diff --git a/src/clientRPC/handlers/agentStatus.ts b/src/clientRPC/handlers/agentStatus.ts index 2d14ba558..c708f16c7 100644 --- a/src/clientRPC/handlers/agentStatus.ts +++ b/src/clientRPC/handlers/agentStatus.ts @@ -1,10 +1,8 @@ -import type { UnaryHandler } from '../../RPC/types'; +import type { ManifestItem } from '../../RPC/types'; import type KeyRing from '../../keys/KeyRing'; import type CertManager from '../../keys/CertManager'; import type Logger from '@matrixai/logger'; import type { NodeIdEncoded } from '../../ids'; -import type RPCClient from '../../RPC/RPCClient'; -import type { POJO } from '../../types'; import * as nodesUtils from '../../nodes/utils'; import * as keysUtils from '../../keys/utils'; @@ -13,36 +11,27 @@ type StatusResult = { nodeId: NodeIdEncoded; publicJwk: string; }; -const agentStatusName = 'agentStatus'; -const agentStatusHandler: UnaryHandler = async ( - input, - container: { - keyRing: KeyRing; - certManager: CertManager; - logger: Logger; - }, - _connectionInfo, - _ctx, -) => { - return { - pid: process.pid, - nodeId: nodesUtils.encodeNodeId(container.keyRing.getNodeId()), - publicJwk: JSON.stringify( - keysUtils.publicKeyToJWK(container.keyRing.keyPair.publicKey), - ), - }; -}; -const agentStatusCaller = async (metadata: POJO, rpcClient: RPCClient) => { - const result = await rpcClient.unaryCaller( - agentStatusName, - null, - ); - return { - pid: result.pid, - nodeId: nodesUtils.decodeNodeId(result.nodeId), - publicJwk: result.publicJwk, - }; +const agentStatus: ManifestItem = { + type: 'UNARY', + handler: async ( + input, + container: { + keyRing: KeyRing; + certManager: CertManager; + logger: Logger; + }, + _connectionInfo, + _ctx, + ) => { + return { + pid: process.pid, + nodeId: nodesUtils.encodeNodeId(container.keyRing.getNodeId()), + publicJwk: JSON.stringify( + keysUtils.publicKeyToJWK(container.keyRing.keyPair.publicKey), + ), + }; + }, }; -export { agentStatusName, agentStatusHandler, agentStatusCaller }; +export { agentStatus }; diff --git a/src/clientRPC/handlers/agentUnlock.ts b/src/clientRPC/handlers/agentUnlock.ts index e1e77ad30..47c804024 100644 --- a/src/clientRPC/handlers/agentUnlock.ts +++ b/src/clientRPC/handlers/agentUnlock.ts @@ -1,41 +1,28 @@ -import type { UnaryHandler } from '../../RPC/types'; +import type { ManifestItem } from '../../RPC/types'; import type Logger from '@matrixai/logger'; -import type RPCClient from '../../RPC/RPCClient'; -import type { JSONValue } from '../../types'; import type { ClientDataAndMetadata } from '../types'; -const agentUnlockName = 'agentStatus'; -const agentUnlockHandler: UnaryHandler< +const agentUnlock: ManifestItem< ClientDataAndMetadata, ClientDataAndMetadata -> = async ( - _input, - _container: { - logger: Logger; +> = { + type: 'UNARY', + handler: async ( + _input, + _container: { + logger: Logger; + }, + _connectionInfo, + _ctx, + ) => { + // This is a NOP handler, + // authentication and unlocking is handled via middleware. + // Failure to authenticate will be an error from the middleware layer. + return { + metadata: {}, + data: null, + }; }, - _connectionInfo, - _ctx, -) => { - // This is a NOP handler, - // authentication and unlocking is handled via middleware. - // Failure to authenticate will be an error from the middleware layer. - return { - metadata: {}, - data: null, - }; }; -const agentUnlockCaller = async ( - metadata: Record, - rpcClient: RPCClient, -) => { - return rpcClient.unaryCaller< - ClientDataAndMetadata, - ClientDataAndMetadata - >(agentUnlockName, { - metadata: metadata, - data: null, - }); -}; - -export { agentUnlockName, agentUnlockHandler, agentUnlockCaller }; +export { agentUnlock }; diff --git a/tests/RPC/RPC.test.ts b/tests/RPC/RPC.test.ts index 03eb95d31..f7efb9369 100644 --- a/tests/RPC/RPC.test.ts +++ b/tests/RPC/RPC.test.ts @@ -1,11 +1,4 @@ -import type { - ClientStreamHandler, - DuplexStreamHandler, - JsonRpcRequest, - RawDuplexStreamHandler, - ServerStreamHandler, - UnaryHandler, -} from '@/RPC/types'; +import type { JsonRpcRequest, ManifestItem } from '@/RPC/types'; import type { ConnectionInfo } from '@/network/types'; import type { JSONValue } from '@/types'; import { fc, testProp } from '@fast-check/jest'; @@ -17,8 +10,6 @@ import * as rpcTestUtils from './utils'; describe('RPC', () => { const logger = new Logger(`RPC Test`, LogLevel.WARN, [new StreamHandler()]); - const methodName = 'testMethod'; - testProp( 'RPC communication with raw stream', [rpcTestUtils.rawDataArb], @@ -30,29 +21,32 @@ describe('RPC', () => { Uint8Array >(); - const container = {}; - const rpcServer = await RPCServer.createRPCServer({ container, logger }); let header: JsonRpcRequest | undefined; - const rawHandler: RawDuplexStreamHandler = ( - [input, header_], - _container, - _connectionInfo, - _ctx, - ) => { - header = header_; - return input; + const testMethod: ManifestItem = { + type: 'RAW', + handler: ([input, header_], _container, _connectionInfo, _ctx) => { + header = header_; + return input; + }, }; - - rpcServer.registerRawStreamHandler(methodName, rawHandler); + const manifest = { + testMethod, + }; + const container = {}; + const rpcServer = await RPCServer.createRPCServer({ + manifest, + container, + logger, + }); rpcServer.handleStream(serverPair, {} as ConnectionInfo); const rpcClient = await RPCClient.createRPCClient({ - manifest: {}, + manifest, streamPairCreateCallback: async () => clientPair, logger, }); - const callerInterface = await rpcClient.rawStreamCaller(methodName, { + const callerInterface = await rpcClient.methods.testMethod({ hello: 'world', }); const writer = callerInterface.writable.getWriter(); @@ -63,7 +57,7 @@ describe('RPC', () => { await writer.close(); const expectedHeader: JsonRpcRequest = { jsonrpc: '2.0', - method: methodName, + method: 'testMethod', params: { hello: 'world' }, id: null, }; @@ -83,26 +77,32 @@ describe('RPC', () => { Uint8Array >(); - const container = {}; - const rpcServer = await RPCServer.createRPCServer({ container, logger }); - - const duplexHandler: DuplexStreamHandler = - async function* (input, _container, _connectionInfo, _ctx) { + const testMethod: ManifestItem = { + type: 'DUPLEX', + handler: async function* (input, _container, _connectionInfo, _ctx) { for await (const val of input) { yield val; } - }; - - rpcServer.registerDuplexStreamHandler(methodName, duplexHandler); + }, + }; + const manifest = { + testMethod, + }; + const container = {}; + const rpcServer = await RPCServer.createRPCServer({ + manifest, + container, + logger, + }); rpcServer.handleStream(serverPair, {} as ConnectionInfo); const rpcClient = await RPCClient.createRPCClient({ - manifest: {}, + manifest, streamPairCreateCallback: async () => clientPair, logger, }); - const callerInterface = await rpcClient.duplexStreamCaller(methodName); + const callerInterface = await rpcClient.methods.testMethod(); const writer = callerInterface.writable.getWriter(); const reader = callerInterface.readable.getReader(); for (const value of values) { @@ -126,29 +126,32 @@ describe('RPC', () => { Uint8Array >(); - const container = {}; - const rpcServer = await RPCServer.createRPCServer({ container, logger }); - - const serverStreamHandler: ServerStreamHandler = - async function* (input, _container, _connectionInfo, _ctx) { + const testMethod: ManifestItem = { + type: 'SERVER', + handler: async function* (input, _container, _connectionInfo, _ctx) { for (let i = 0; i < input; i++) { yield i; } - }; - - rpcServer.registerServerStreamHandler(methodName, serverStreamHandler); + }, + }; + const manifest = { + testMethod, + }; + const container = {}; + const rpcServer = await RPCServer.createRPCServer({ + manifest, + container, + logger, + }); rpcServer.handleStream(serverPair, {} as ConnectionInfo); const rpcClient = await RPCClient.createRPCClient({ - manifest: {}, + manifest, streamPairCreateCallback: async () => clientPair, logger, }); - const callerInterface = await rpcClient.serverStreamCaller< - number, - number - >(methodName, value); + const callerInterface = await rpcClient.methods.testMethod(value); const outputs: Array = []; for await (const num of callerInterface) { @@ -168,31 +171,34 @@ describe('RPC', () => { Uint8Array >(); - const container = {}; - const rpcServer = await RPCServer.createRPCServer({ container, logger }); - - const clientStreamhandler: ClientStreamHandler = async ( - input, - ) => { - let acc = 0; - for await (const number of input) { - acc += number; - } - return acc; + const testMethod: ManifestItem = { + type: 'CLIENT', + handler: async (input) => { + let acc = 0; + for await (const number of input) { + acc += number; + } + return acc; + }, + }; + const manifest = { + testMethod, }; - rpcServer.registerClientStreamHandler(methodName, clientStreamhandler); + const container = {}; + const rpcServer = await RPCServer.createRPCServer({ + manifest, + container, + logger, + }); rpcServer.handleStream(serverPair, {} as ConnectionInfo); const rpcClient = await RPCClient.createRPCClient({ - manifest: {}, + manifest, streamPairCreateCallback: async () => clientPair, logger, }); - const callerInterface = await rpcClient.clientStreamCaller< - number, - number - >(methodName); + const callerInterface = await rpcClient.methods.testMethod(); const writer = callerInterface.writable.getWriter(); for (const value of values) { await writer.write(value); @@ -214,24 +220,28 @@ describe('RPC', () => { Uint8Array >(); + const testMethod: ManifestItem = { + type: 'UNARY', + handler: async (input) => input, + }; + const manifest = { + testMethod, + }; const container = {}; - const rpcServer = await RPCServer.createRPCServer({ container, logger }); - - const unaryCaller: UnaryHandler = async (input) => - input; - rpcServer.registerUnaryHandler(methodName, unaryCaller); + const rpcServer = await RPCServer.createRPCServer({ + manifest, + container, + logger, + }); rpcServer.handleStream(serverPair, {} as ConnectionInfo); const rpcClient = await RPCClient.createRPCClient({ - manifest: {}, + manifest, streamPairCreateCallback: async () => clientPair, logger, }); - const result = await rpcClient.unaryCaller( - methodName, - value, - ); + const result = await rpcClient.methods.testMethod(value); expect(result).toStrictEqual(value); await rpcServer.destroy(); await rpcClient.destroy(); diff --git a/tests/RPC/RPCServer.test.ts b/tests/RPC/RPCServer.test.ts index 8c772f1cc..5e2cdfaa2 100644 --- a/tests/RPC/RPCServer.test.ts +++ b/tests/RPC/RPCServer.test.ts @@ -1,13 +1,9 @@ import type { - ClientStreamHandler, - DuplexStreamHandler, JsonRpcMessage, JsonRpcRequest, JsonRpcResponse, JsonRpcResponseError, - RawDuplexStreamHandler, - ServerStreamHandler, - UnaryHandler, + ManifestItem, } from '@/RPC/types'; import type { JSONValue } from '@/types'; import type { ConnectionInfo, Host, Port } from '@/network/types'; @@ -67,34 +63,35 @@ describe(`${RPCServer.name}`, () => { .pipeThrough( new rpcTestUtils.BufferStreamToSnippedStream([4, 7, 13, 2, 6]), ); + const testMethod: ManifestItem = { + type: 'RAW', + handler: ([input]) => { + void (async () => { + for await (const _ of input) { + // No touch, only consume + } + })().catch(() => {}); + return new ReadableStream({ + start: (controller) => { + controller.enqueue(Buffer.from('hello world!')); + controller.close(); + }, + }); + }, + }; const container = {}; - const rpcServer = await RPCServer.createRPCServer({ container, logger }); + const rpcServer = await RPCServer.createRPCServer({ + manifest: { + testMethod, + }, + container, + logger, + }); const [outputResult, outputStream] = rpcTestUtils.streamToArray(); const readWriteStream: ReadableWritablePair = { readable: stream, writable: outputStream, }; - - const rawDuplexHandler: RawDuplexStreamHandler = ( - [input], - _container, - _connectionInfo, - _ctx, - ) => { - void (async () => { - for await (const _ of input) { - // No touch, only consume - } - })().catch(() => {}); - return new ReadableStream({ - start: (controller) => { - controller.enqueue(Buffer.from('hello world!')); - controller.close(); - }, - }); - }; - - rpcServer.registerRawStreamHandler(methodName, rawDuplexHandler); rpcServer.handleStream(readWriteStream, {} as ConnectionInfo); await outputResult; await rpcServer.destroy(); @@ -106,23 +103,28 @@ describe(`${RPCServer.name}`, () => { [specificMessageArb], async (messages) => { const stream = rpcTestUtils.jsonRpcStream(messages); + const testMethod: ManifestItem = { + type: 'DUPLEX', + handler: async function* (input, _container, _connectionInfo, _ctx) { + for await (const val of input) { + yield val; + break; + } + }, + }; const container = {}; - const rpcServer = await RPCServer.createRPCServer({ container, logger }); + const rpcServer = await RPCServer.createRPCServer({ + manifest: { + testMethod, + }, + container, + logger, + }); const [outputResult, outputStream] = rpcTestUtils.streamToArray(); const readWriteStream: ReadableWritablePair = { readable: stream, writable: outputStream, }; - - const duplexHandler: DuplexStreamHandler = - async function* (input, _container, _connectionInfo, _ctx) { - for await (const val of input) { - yield val; - break; - } - }; - - rpcServer.registerDuplexStreamHandler(methodName, duplexHandler); rpcServer.handleStream(readWriteStream, {} as ConnectionInfo); await outputResult; await rpcServer.destroy(); @@ -133,24 +135,29 @@ describe(`${RPCServer.name}`, () => { [specificMessageArb], async (messages) => { const stream = rpcTestUtils.jsonRpcStream(messages); + const testMethod: ManifestItem = { + type: 'CLIENT', + handler: async function (input, _container, _connectionInfo, _ctx) { + let count = 0; + for await (const _ of input) { + count += 1; + } + return count; + }, + }; const container = {}; - const rpcServer = await RPCServer.createRPCServer({ container, logger }); + const rpcServer = await RPCServer.createRPCServer({ + manifest: { + testMethod, + }, + container, + logger, + }); const [outputResult, outputStream] = rpcTestUtils.streamToArray(); const readWriteStream: ReadableWritablePair = { readable: stream, writable: outputStream, }; - - const clientHandler: ClientStreamHandler = - async function (input, _container, _connectionInfo, _ctx) { - let count = 0; - for await (const _ of input) { - count += 1; - } - return count; - }; - - rpcServer.registerClientStreamHandler(methodName, clientHandler); rpcServer.handleStream(readWriteStream, {} as ConnectionInfo); await outputResult; await rpcServer.destroy(); @@ -161,22 +168,27 @@ describe(`${RPCServer.name}`, () => { [singleNumberMessageArb], async (messages) => { const stream = rpcTestUtils.jsonRpcStream(messages); + const testMethod: ManifestItem = { + type: 'SERVER', + handler: async function* (input, _container, _connectionInfo, _ctx) { + for (let i = 0; i < input; i++) { + yield i; + } + }, + }; const container = {}; - const rpcServer = await RPCServer.createRPCServer({ container, logger }); + const rpcServer = await RPCServer.createRPCServer({ + manifest: { + testMethod, + }, + container, + logger, + }); const [outputResult, outputStream] = rpcTestUtils.streamToArray(); const readWriteStream: ReadableWritablePair = { readable: stream, writable: outputStream, }; - - const serverHandler: ServerStreamHandler = - async function* (input, _container, _connectionInfo, _ctx) { - for (let i = 0; i < input; i++) { - yield i; - } - }; - - rpcServer.registerServerStreamHandler(methodName, serverHandler); rpcServer.handleStream(readWriteStream, {} as ConnectionInfo); await outputResult; await rpcServer.destroy(); @@ -187,24 +199,23 @@ describe(`${RPCServer.name}`, () => { [specificMessageArb], async (messages) => { const stream = rpcTestUtils.jsonRpcStream(messages); + const testMethod: ManifestItem = { + type: 'UNARY', + handler: async (input, _container, _connectionInfo, _ctx) => input, + }; const container = {}; - const rpcServer = await RPCServer.createRPCServer({ container, logger }); + const rpcServer = await RPCServer.createRPCServer({ + manifest: { + testMethod, + }, + container, + logger, + }); const [outputResult, outputStream] = rpcTestUtils.streamToArray(); const readWriteStream: ReadableWritablePair = { readable: stream, writable: outputStream, }; - - const unaryHandler: UnaryHandler = async function ( - input, - _container, - _connectionInfo, - _ctx, - ) { - return input; - }; - - rpcServer.registerUnaryHandler(methodName, unaryHandler); rpcServer.handleStream(readWriteStream, {} as ConnectionInfo); await outputResult; await rpcServer.destroy(); @@ -215,27 +226,32 @@ describe(`${RPCServer.name}`, () => { [specificMessageArb], async (messages) => { const stream = rpcTestUtils.jsonRpcStream(messages); + const testMethod: ManifestItem = { + type: 'DUPLEX', + handler: async function* (input, container_, _connectionInfo, _ctx) { + expect(container_).toBe(container); + for await (const val of input) { + yield val; + } + }, + }; const container = { a: Symbol('a'), B: Symbol('b'), C: Symbol('c'), }; - const rpcServer = await RPCServer.createRPCServer({ container, logger }); + const rpcServer = await RPCServer.createRPCServer({ + manifest: { + testMethod, + }, + container, + logger, + }); const [outputResult, outputStream] = rpcTestUtils.streamToArray(); const readWriteStream: ReadableWritablePair = { readable: stream, writable: outputStream, }; - - const duplexHandler: DuplexStreamHandler = - async function* (input, container_, _connectionInfo, _ctx) { - expect(container_).toBe(container); - for await (const val of input) { - yield val; - } - }; - - rpcServer.registerDuplexStreamHandler(methodName, duplexHandler); rpcServer.handleStream(readWriteStream, {} as ConnectionInfo); await outputResult; await rpcServer.destroy(); @@ -254,23 +270,29 @@ describe(`${RPCServer.name}`, () => { remoteNodeId: 'asd' as unknown as NodeId, remotePort: 12341 as Port, }; + let handledConnectionInfo; + const testMethod: ManifestItem = { + type: 'DUPLEX', + handler: async function* (input, _container, connectionInfo_, _ctx) { + handledConnectionInfo = connectionInfo_; + for await (const val of input) { + yield val; + } + }, + }; const container = {}; - const rpcServer = await RPCServer.createRPCServer({ container, logger }); + const rpcServer = await RPCServer.createRPCServer({ + manifest: { + testMethod, + }, + container, + logger, + }); const [outputResult, outputStream] = rpcTestUtils.streamToArray(); const readWriteStream: ReadableWritablePair = { readable: stream, writable: outputStream, }; - - let handledConnectionInfo; - const duplexHandler: DuplexStreamHandler = - async function* (input, _container, connectionInfo_, _ctx) { - handledConnectionInfo = connectionInfo_; - for await (const val of input) { - yield val; - } - }; - rpcServer.registerDuplexStreamHandler(methodName, duplexHandler); rpcServer.handleStream(readWriteStream, connectionInfo); await outputResult; await rpcServer.destroy(); @@ -284,8 +306,23 @@ describe(`${RPCServer.name}`, () => { [specificMessageArb], async (messages) => { const stream = rpcTestUtils.jsonRpcStream(messages); + const testMethod: ManifestItem = { + type: 'DUPLEX', + handler: async function* (input, _container, _connectionInfo, ctx) { + for await (const val of input) { + if (ctx.signal.aborted) throw ctx.signal.reason; + yield val; + } + }, + }; const container = {}; - const rpcServer = await RPCServer.createRPCServer({ container, logger }); + const rpcServer = await RPCServer.createRPCServer({ + manifest: { + testMethod, + }, + container, + logger, + }); const [outputResult, outputStream] = rpcTestUtils.streamToArray(); let thing; let lastMessage: JsonRpcMessage | undefined; @@ -307,15 +344,6 @@ describe(`${RPCServer.name}`, () => { readable: stream, writable: tapStream.writable, }; - - const duplexHandler: DuplexStreamHandler = - async function* (input, _container, _connectionInfo, ctx) { - for await (const val of input) { - if (ctx.signal.aborted) throw ctx.signal.reason; - yield val; - } - }; - rpcServer.registerDuplexStreamHandler(methodName, duplexHandler); rpcServer.handleStream(readWriteStream, {} as ConnectionInfo); await outputResult; await expect(thing).toResolve(); @@ -326,22 +354,27 @@ describe(`${RPCServer.name}`, () => { ); testProp('Handler yields nothing', [specificMessageArb], async (messages) => { const stream = rpcTestUtils.jsonRpcStream(messages); + const testMethod: ManifestItem = { + type: 'DUPLEX', + handler: async function* (input, _container, _connectionInfo, _ctx) { + for await (const _ of input) { + // Do nothing, just consume + } + }, + }; const container = {}; - const rpcServer = await RPCServer.createRPCServer({ container, logger }); + const rpcServer = await RPCServer.createRPCServer({ + manifest: { + testMethod, + }, + container, + logger, + }); const [outputResult, outputStream] = rpcTestUtils.streamToArray(); const readWriteStream: ReadableWritablePair = { readable: stream, writable: outputStream, }; - - const duplexHandler: DuplexStreamHandler = - async function* (input, _container, _connectionInfo, _ctx) { - for await (const _ of input) { - // Do nothing, just consume - } - }; - - rpcServer.registerDuplexStreamHandler(methodName, duplexHandler); rpcServer.handleStream(readWriteStream, {} as ConnectionInfo); await outputResult; // We're just expecting no errors @@ -352,8 +385,20 @@ describe(`${RPCServer.name}`, () => { [specificMessageArb, errorArb], async (messages, error) => { const stream = rpcTestUtils.jsonRpcStream(messages); + const testMethod: ManifestItem = { + type: 'DUPLEX', + handler: async function* (_input, _container, _connectionInfo, _ctx) { + throw error; + }, + }; const container = {}; - const rpcServer = await RPCServer.createRPCServer({ container, logger }); + const rpcServer = await RPCServer.createRPCServer({ + manifest: { + testMethod, + }, + container, + logger, + }); let resolve, reject; const errorProm = new Promise((resolve_, reject_) => { resolve = resolve_; @@ -367,13 +412,6 @@ describe(`${RPCServer.name}`, () => { readable: stream, writable: outputStream, }; - - const duplexHandler: DuplexStreamHandler = - async function* (_input, _container, _connectionInfo, _ctx) { - throw error; - }; - - rpcServer.registerDuplexStreamHandler(methodName, duplexHandler); rpcServer.handleStream(readWriteStream, {} as ConnectionInfo); const errorMessage = JSON.parse((await outputResult)[0]!.toString()); expect(errorMessage.error.code).toEqual(error.exitCode); @@ -388,8 +426,20 @@ describe(`${RPCServer.name}`, () => { [specificMessageArb], async (messages) => { const stream = rpcTestUtils.jsonRpcStream(messages); + const testMethod: ManifestItem = { + type: 'DUPLEX', + handler: async function* (_input, _container, _connectionInfo, _ctx) { + throw new rpcErrors.ErrorRpcPlaceholderConnectionError(); + }, + }; const container = {}; - const rpcServer = await RPCServer.createRPCServer({ container, logger }); + const rpcServer = await RPCServer.createRPCServer({ + manifest: { + testMethod, + }, + container, + logger, + }); let resolve, reject; const errorProm = new Promise((resolve_, reject_) => { resolve = resolve_; @@ -403,13 +453,6 @@ describe(`${RPCServer.name}`, () => { readable: stream, writable: outputStream, }; - - const duplexHandler: DuplexStreamHandler = - async function* (_input, _container, _connectionInfo, _ctx) { - throw new rpcErrors.ErrorRpcPlaceholderConnectionError(); - }; - - rpcServer.registerDuplexStreamHandler(methodName, duplexHandler); rpcServer.handleStream(readWriteStream, {} as ConnectionInfo); await outputResult; @@ -420,22 +463,27 @@ describe(`${RPCServer.name}`, () => { ); testProp('forward middlewares', [specificMessageArb], async (messages) => { const stream = rpcTestUtils.jsonRpcStream(messages); + const testMethod: ManifestItem = { + type: 'DUPLEX', + handler: async function* (input, _container, _connectionInfo, _ctx) { + for await (const val of input) { + yield val; + } + }, + }; const container = {}; - const rpcServer = await RPCServer.createRPCServer({ container, logger }); + const rpcServer = await RPCServer.createRPCServer({ + manifest: { + testMethod, + }, + container, + logger, + }); const [outputResult, outputStream] = rpcTestUtils.streamToArray(); const readWriteStream: ReadableWritablePair = { readable: stream, writable: outputStream, }; - - const duplexHandler: DuplexStreamHandler = - async function* (input, _container, _connectionInfo, _ctx) { - for await (const val of input) { - yield val; - } - }; - - rpcServer.registerDuplexStreamHandler(methodName, duplexHandler); rpcServer.registerMiddleware(() => { return { forward: new TransformStream({ @@ -462,22 +510,27 @@ describe(`${RPCServer.name}`, () => { }); testProp('reverse middlewares', [specificMessageArb], async (messages) => { const stream = rpcTestUtils.jsonRpcStream(messages); + const testMethod: ManifestItem = { + type: 'DUPLEX', + handler: async function* (input, _container, _connectionInfo, _ctx) { + for await (const val of input) { + yield val; + } + }, + }; const container = {}; - const rpcServer = await RPCServer.createRPCServer({ container, logger }); + const rpcServer = await RPCServer.createRPCServer({ + manifest: { + testMethod, + }, + container, + logger, + }); const [outputResult, outputStream] = rpcTestUtils.streamToArray(); const readWriteStream: ReadableWritablePair = { readable: stream, writable: outputStream, }; - - const duplexHandler: DuplexStreamHandler = - async function* (input, _container, _connectionInfo, _ctx) { - for await (const val of input) { - yield val; - } - }; - - rpcServer.registerDuplexStreamHandler(methodName, duplexHandler); rpcServer.registerMiddleware(() => { return { forward: new TransformStream(), @@ -507,23 +560,27 @@ describe(`${RPCServer.name}`, () => { [invalidTokenMessageArb], async (message) => { const stream = rpcTestUtils.jsonRpcStream([message]); + const testMethod: ManifestItem = { + type: 'DUPLEX', + handler: async function* (input, _container, _connectionInfo, _ctx) { + for await (const val of input) { + yield val; + } + }, + }; const container = {}; - const rpcServer = await RPCServer.createRPCServer({ container, logger }); + const rpcServer = await RPCServer.createRPCServer({ + manifest: { + testMethod, + }, + container, + logger, + }); const [outputResult, outputStream] = rpcTestUtils.streamToArray(); const readWriteStream: ReadableWritablePair = { readable: stream, writable: outputStream, }; - - const duplexHandler: DuplexStreamHandler = - async function* (input, _container, _connectionInfo, _ctx) { - for await (const val of input) { - yield val; - } - }; - - rpcServer.registerDuplexStreamHandler(methodName, duplexHandler); - type TestType = { metadata: { token: string; diff --git a/tests/clientRPC/handlers/agentStatus.test.ts b/tests/clientRPC/handlers/agentStatus.test.ts index 68fdc155e..4b94d381f 100644 --- a/tests/clientRPC/handlers/agentStatus.test.ts +++ b/tests/clientRPC/handlers/agentStatus.test.ts @@ -11,13 +11,10 @@ import * as keysUtils from '@/keys/utils'; import RPCServer from '@/RPC/RPCServer'; import TaskManager from '@/tasks/TaskManager'; import CertManager from '@/keys/CertManager'; -import { - agentStatusName, - agentStatusHandler, - agentStatusCaller, -} from '@/clientRPC/handlers/agentStatus'; +import { agentStatus } from '@/clientRPC/handlers/agentStatus'; import RPCClient from '@/RPC/RPCClient'; import * as clientRPCUtils from '@/clientRPC/utils'; +import * as nodesUtils from '@/nodes/utils'; import * as testsUtils from '../../utils'; describe('agentStatus', () => { @@ -81,7 +78,11 @@ describe('agentStatus', () => { }); test('get status %s', async () => { // Setup + const manifest = { + agentStatus, + }; const rpcServer = await RPCServer.createRPCServer({ + manifest, container: { keyRing, certManager, @@ -89,14 +90,13 @@ describe('agentStatus', () => { }, logger: logger.getChild('RPCServer'), }); - rpcServer.registerUnaryHandler(agentStatusName, agentStatusHandler); wss = clientRPCUtils.createClientServer( server, rpcServer, logger.getChild('server'), ); const rpcClient = await RPCClient.createRPCClient({ - manifest: {}, + manifest, streamPairCreateCallback: async () => { return clientRPCUtils.startConnection( host, @@ -107,10 +107,10 @@ describe('agentStatus', () => { logger: logger.getChild('RPCClient'), }); // Doing the test - const result = await agentStatusCaller({}, rpcClient); + const result = await rpcClient.methods.agentStatus(null); expect(result).toStrictEqual({ pid: process.pid, - nodeId: keyRing.getNodeId(), + nodeId: nodesUtils.encodeNodeId(keyRing.getNodeId()), publicJwk: JSON.stringify( keysUtils.publicKeyToJWK(keyRing.keyPair.publicKey), ), diff --git a/tests/clientRPC/handlers/agentUnlock.test.ts b/tests/clientRPC/handlers/agentUnlock.test.ts index 5e6cdcf45..aeca08758 100644 --- a/tests/clientRPC/handlers/agentUnlock.test.ts +++ b/tests/clientRPC/handlers/agentUnlock.test.ts @@ -11,11 +11,7 @@ import * as keysUtils from '@/keys/utils'; import RPCServer from '@/RPC/RPCServer'; import TaskManager from '@/tasks/TaskManager'; import CertManager from '@/keys/CertManager'; -import { - agentUnlockName, - agentUnlockHandler, - agentUnlockCaller, -} from '@/clientRPC/handlers/agentUnlock'; +import { agentUnlock } from '@/clientRPC/handlers/agentUnlock'; import RPCClient from '@/RPC/RPCClient'; import { Session, SessionManager } from '@/sessions'; import * as abcUtils from '@/clientRPC/utils'; @@ -94,13 +90,16 @@ describe('agentUnlock', () => { }); test('get status', async () => { // Setup + const manifest = { + agentUnlock, + }; const rpcServer = await RPCServer.createRPCServer({ + manifest, container: { logger, }, logger, }); - rpcServer.registerUnaryHandler(agentUnlockName, agentUnlockHandler); rpcServer.registerMiddleware( abcUtils.authenticationMiddlewareServer(sessionManager, keyRing), ); @@ -110,6 +109,7 @@ describe('agentUnlock', () => { logger.getChild('server'), ); const rpcClient = await RPCClient.createRPCClient({ + manifest, streamPairCreateCallback: async () => { return clientRPCUtils.startConnection( '127.0.0.1', @@ -124,19 +124,22 @@ describe('agentUnlock', () => { ); // Doing the test - const result = await agentUnlockCaller( - { + const result = await rpcClient.methods.agentUnlock({ + metadata: { Authorization: abcUtils.encodeAuthFromPassword(password), }, - rpcClient, - ); + data: null, + }); expect(result).toMatchObject({ metadata: { Authorization: expect.any(String), }, data: null, }); - const result2 = await agentUnlockCaller({}, rpcClient); + const result2 = await rpcClient.methods.agentUnlock({ + metadata: {}, + data: null, + }); expect(result2).toMatchObject({ metadata: { Authorization: expect.any(String), diff --git a/tests/clientRPC/websocket.test.ts b/tests/clientRPC/websocket.test.ts index 818927c4a..e66526bfc 100644 --- a/tests/clientRPC/websocket.test.ts +++ b/tests/clientRPC/websocket.test.ts @@ -1,6 +1,8 @@ import type { TLSConfig } from '@/network/types'; import type { Server } from 'https'; import type { WebSocketServer } from 'ws'; +import type { ManifestItem } from '@/RPC/types'; +import type { JSONValue } from '@/types'; import fs from 'fs'; import path from 'path'; import os from 'os'; @@ -52,20 +54,27 @@ describe('websocket', () => { test('websocket should work with RPC', async () => { // Setting up server + const test1: ManifestItem = { + type: 'UNARY', + handler: async (params, _container, _connectionInfo) => { + return params; + }, + }; + const test2: ManifestItem = { + type: 'UNARY', + handler: async () => { + return { hello: 'not world' }; + }, + }; + const manifest = { + test1, + test2, + }; const rpcServer = new RPCServer({ + manifest, container: {}, logger: logger.getChild('RPCServer'), }); - rpcServer.registerUnaryHandler( - 'test1', - async (params, _container, _connectionInfo) => { - return params; - }, - ); - rpcServer.registerUnaryHandler('test2', async () => { - return { hello: 'not world' }; - }); - wss = clientRPCUtils.createClientServer( server, rpcServer, @@ -74,7 +83,7 @@ describe('websocket', () => { // Setting up client const rpcClient = await RPCClient.createRPCClient({ - manifest: {}, + manifest, logger: logger.getChild('RPCClient'), streamPairCreateCallback: async () => { return clientRPCUtils.startConnection( @@ -87,10 +96,10 @@ describe('websocket', () => { // Making the call await expect( - rpcClient.unaryCaller('test1', { hello: 'world2' }), + rpcClient.methods.test1({ hello: 'world2' }), ).resolves.toStrictEqual({ hello: 'world2' }); await expect( - rpcClient.unaryCaller('test2', { hello: 'world2' }), + rpcClient.methods.test2({ hello: 'world2' }), ).resolves.toStrictEqual({ hello: 'not world' }); await expect( rpcClient.unaryCaller('test3', { hello: 'world2' }), From fe0a0b72f05cc9951f6b7517bc17e170901e0b0e Mon Sep 17 00:00:00 2001 From: Brian Botha Date: Tue, 7 Feb 2023 20:34:42 +1100 Subject: [PATCH 31/44] feat: static middleware for server - Related #500 [ci skip] --- src/RPC/RPCClient.ts | 9 +- src/RPC/RPCServer.ts | 64 +++++------- src/RPC/types.ts | 8 +- src/RPC/utils.ts | 56 ++++++++++ src/clientRPC/utils.ts | 4 + tests/RPC/RPCServer.test.ts | 102 ++++++++++--------- tests/clientRPC/handlers/agentUnlock.test.ts | 12 +-- tests/clientRPC/websocket.test.ts | 2 +- 8 files changed, 160 insertions(+), 97 deletions(-) diff --git a/src/RPC/RPCClient.ts b/src/RPC/RPCClient.ts index 06513e0e3..b515baa9d 100644 --- a/src/RPC/RPCClient.ts +++ b/src/RPC/RPCClient.ts @@ -299,13 +299,20 @@ class RPCClient { } protected middleware: Array< - MiddlewareFactory, JsonRpcResponse> + MiddlewareFactory< + JsonRpcRequest, + JsonRpcRequest, + JsonRpcResponse, + JsonRpcResponse + > > = []; @ready(new rpcErrors.ErrorRpcDestroyed()) public registerMiddleware( middlewareFactory: MiddlewareFactory< JsonRpcRequest, + JsonRpcRequest, + JsonRpcResponse, JsonRpcResponse >, ) { diff --git a/src/RPC/RPCServer.ts b/src/RPC/RPCServer.ts index d547f66f4..81a8ea4b7 100644 --- a/src/RPC/RPCServer.ts +++ b/src/RPC/RPCServer.ts @@ -31,16 +31,24 @@ class RPCServer { static async createRPCServer({ manifest, container, + middleware = rpcUtils.defaultMiddlewareWrapper(), logger = new Logger(this.name), }: { manifest: Manifest; container: POJO; + middleware?: MiddlewareFactory< + JsonRpcRequest, + Uint8Array, + Uint8Array, + JsonRpcResponseResult + >; logger?: Logger; }): Promise { logger.info(`Creating ${this.name}`); const rpcServer = new this({ manifest, container, + middleware, logger, }); logger.info(`Created ${this.name}`); @@ -53,14 +61,27 @@ class RPCServer { protected handlerMap: Map = new Map(); protected activeStreams: Set> = new Set(); protected events: EventTarget = new EventTarget(); + protected middleware: MiddlewareFactory< + JsonRpcRequest, + Uint8Array, + Uint8Array, + JsonRpcResponseResult + >; public constructor({ manifest, container, + middleware, logger, }: { manifest: Manifest; container: POJO; + middleware: MiddlewareFactory< + JsonRpcRequest, + Uint8Array, + Uint8Array, + JsonRpcResponseResult + >; logger: Logger; }) { for (const [key, manifestItem] of Object.entries(manifest)) { @@ -85,6 +106,7 @@ class RPCServer { } } this.container = container; + this.middleware = middleware; this.logger = logger; } @@ -120,23 +142,10 @@ class RPCServer { connectionInfo, ctx, ) => { - // Middleware - const outputTransformStream = new rpcUtils.JsonMessageToJsonStream(); - const outputReadableSteam = outputTransformStream.readable; - let forwardStream = input.pipeThrough( - new rpcUtils.JsonToJsonMessageStream( - rpcUtils.parseJsonRpcRequest, - undefined, - header, - ), - ); - let reverseStream = outputTransformStream.writable; - for (const middlewareFactory of this.middleware) { - const middleware = middlewareFactory(); - forwardStream = forwardStream.pipeThrough(middleware.forward); - void middleware.reverse.readable.pipeTo(reverseStream).catch(() => {}); - reverseStream = middleware.reverse.writable; - } + // Setting up middleware + const middleware = this.middleware(header); + const forwardStream = input.pipeThrough(middleware.forward); + const reverseStream = middleware.reverse.writable; const events = this.events; const outputGen = async function* (): AsyncGenerator< JsonRpcResponse @@ -206,7 +215,7 @@ class RPCServer { }); void reverseMiddlewareStream.pipeTo(reverseStream).catch(() => {}); - return outputReadableSteam; + return middleware.reverse.readable; }; this.registerRawStreamHandler(method, rawSteamHandler); @@ -344,25 +353,6 @@ class RPCServer { ) { this.events.removeEventListener(type, callback, options); } - - protected middleware: Array< - MiddlewareFactory, JsonRpcResponse> - > = []; - - @ready(new rpcErrors.ErrorRpcDestroyed()) - public registerMiddleware( - middlewareFactory: MiddlewareFactory< - JsonRpcRequest, - JsonRpcResponse - >, - ) { - this.middleware.push(middlewareFactory); - } - - @ready(new rpcErrors.ErrorRpcDestroyed()) - public clearMiddleware() { - this.middleware = []; - } } export default RPCServer; diff --git a/src/RPC/types.ts b/src/RPC/types.ts index 1995030dd..312e65973 100644 --- a/src/RPC/types.ts +++ b/src/RPC/types.ts @@ -135,9 +135,11 @@ type StreamPairCreateCallback = () => Promise< ReadableWritablePair >; -type MiddlewareFactory = () => { - forward: ReadableWritablePair; - reverse: ReadableWritablePair; +type MiddlewareFactory = ( + header?: JsonRpcRequest, +) => { + forward: ReadableWritablePair; + reverse: ReadableWritablePair; }; type DuplexStreamCaller< diff --git a/src/RPC/utils.ts b/src/RPC/utils.ts index 9a05e9b28..9cb9479ce 100644 --- a/src/RPC/utils.ts +++ b/src/RPC/utils.ts @@ -13,6 +13,7 @@ import type { JsonRpcResponseResult, JsonRpcRequest, JsonRpcResponse, + MiddlewareFactory, } from 'RPC/types'; import type { JSONValue } from '../types'; import type { JsonValue } from 'fast-check'; @@ -637,6 +638,59 @@ function getHandlerTypes(manifest: Manifest): Record { return out; } +const defaultMiddleware: MiddlewareFactory< + JsonRpcRequest, + JsonRpcRequest, + JsonRpcResponse, + JsonRpcResponse +> = () => { + return { + forward: new TransformStream(), + reverse: new TransformStream(), + }; +}; + +const defaultMiddlewareWrapper = ( + middleware: MiddlewareFactory< + JsonRpcRequest, + JsonRpcRequest, + JsonRpcResponse, + JsonRpcResponse + > = defaultMiddleware, +) => { + return (header: JsonRpcRequest) => { + const inputTransformStream = new JsonToJsonMessageStream( + parseJsonRpcRequest, + undefined, + header, + ); + const outputTransformStream = new TransformStream< + JsonRpcResponseResult, + JsonRpcResponseResult + >(); + + const middleMiddleware = middleware(header); + + const forwardReadable = inputTransformStream.readable.pipeThrough( + middleMiddleware.forward, + ); // Usual middleware here + const reverseReadable = outputTransformStream.readable + .pipeThrough(middleMiddleware.reverse) // Usual middleware here + .pipeThrough(new JsonMessageToJsonStream()); + + return { + forward: { + readable: forwardReadable, + writable: inputTransformStream.writable, + }, + reverse: { + readable: reverseReadable, + writable: outputTransformStream.writable, + }, + }; + }; +}; + export { JsonToJsonMessageStream, JsonMessageToJsonStream, @@ -657,4 +711,6 @@ export { QueueMergingTransformStream, extractFirstMessageTransform, getHandlerTypes, + defaultMiddleware, + defaultMiddlewareWrapper, }; diff --git a/src/clientRPC/utils.ts b/src/clientRPC/utils.ts index 631825de3..2a7d074d0 100644 --- a/src/clientRPC/utils.ts +++ b/src/clientRPC/utils.ts @@ -77,6 +77,8 @@ function authenticationMiddlewareServer( keyRing: KeyRing, ): MiddlewareFactory< JsonRpcRequest>, + JsonRpcRequest>, + JsonRpcResponse>, JsonRpcResponse> > { return () => { @@ -127,6 +129,8 @@ function authenticationMiddlewareClient( session: Session, ): MiddlewareFactory< JsonRpcRequest>, + JsonRpcRequest>, + JsonRpcResponse>, JsonRpcResponse> > { return () => { diff --git a/tests/RPC/RPCServer.test.ts b/tests/RPC/RPCServer.test.ts index 5e2cdfaa2..c4f35b6dc 100644 --- a/tests/RPC/RPCServer.test.ts +++ b/tests/RPC/RPCServer.test.ts @@ -14,6 +14,7 @@ import { fc, testProp } from '@fast-check/jest'; import Logger, { LogLevel, StreamHandler } from '@matrixai/logger'; import RPCServer from '@/RPC/RPCServer'; import * as rpcErrors from '@/RPC/errors'; +import * as rpcUtils from '@/RPC/utils'; import * as rpcTestUtils from './utils'; describe(`${RPCServer.name}`, () => { @@ -472,10 +473,22 @@ describe(`${RPCServer.name}`, () => { }, }; const container = {}; + const middleware = rpcUtils.defaultMiddlewareWrapper(() => { + return { + forward: new TransformStream({ + transform: (chunk, controller) => { + chunk.params = 1; + controller.enqueue(chunk); + }, + }), + reverse: new TransformStream(), + }; + }); const rpcServer = await RPCServer.createRPCServer({ manifest: { testMethod, }, + middleware, container, logger, }); @@ -484,17 +497,6 @@ describe(`${RPCServer.name}`, () => { readable: stream, writable: outputStream, }; - rpcServer.registerMiddleware(() => { - return { - forward: new TransformStream({ - transform: (chunk, controller) => { - chunk.params = 1; - controller.enqueue(chunk); - }, - }), - reverse: new TransformStream(), - }; - }); rpcServer.handleStream(readWriteStream, {} as ConnectionInfo); const out = await outputResult; expect(out.map((v) => v!.toString())).toStrictEqual( @@ -519,10 +521,22 @@ describe(`${RPCServer.name}`, () => { }, }; const container = {}; + const middleware = rpcUtils.defaultMiddlewareWrapper(() => { + return { + forward: new TransformStream(), + reverse: new TransformStream({ + transform: (chunk, controller) => { + if ('result' in chunk) chunk.result = 1; + controller.enqueue(chunk); + }, + }), + }; + }); const rpcServer = await RPCServer.createRPCServer({ manifest: { testMethod, }, + middleware, container, logger, }); @@ -531,17 +545,6 @@ describe(`${RPCServer.name}`, () => { readable: stream, writable: outputStream, }; - rpcServer.registerMiddleware(() => { - return { - forward: new TransformStream(), - reverse: new TransformStream({ - transform: (chunk, controller) => { - if ('result' in chunk) chunk.result = 1; - controller.enqueue(chunk); - }, - }), - }; - }); rpcServer.handleStream(readWriteStream, {} as ConnectionInfo); const out = await outputResult; expect(out.map((v) => v!.toString())).toStrictEqual( @@ -569,33 +572,7 @@ describe(`${RPCServer.name}`, () => { }, }; const container = {}; - const rpcServer = await RPCServer.createRPCServer({ - manifest: { - testMethod, - }, - container, - logger, - }); - const [outputResult, outputStream] = rpcTestUtils.streamToArray(); - const readWriteStream: ReadableWritablePair = { - readable: stream, - writable: outputStream, - }; - type TestType = { - metadata: { - token: string; - }; - data: JSONValue; - }; - const failureMessage: JsonRpcResponseError = { - jsonrpc: '2.0', - id: null, - error: { - code: 1, - message: 'failure of somekind', - }, - }; - rpcServer.registerMiddleware(() => { + const middleware = rpcUtils.defaultMiddlewareWrapper(() => { let first = true; let reverseController: TransformStreamDefaultController< JsonRpcResponse @@ -627,6 +604,33 @@ describe(`${RPCServer.name}`, () => { }), }; }); + const rpcServer = await RPCServer.createRPCServer({ + manifest: { + testMethod, + }, + middleware, + container, + logger, + }); + const [outputResult, outputStream] = rpcTestUtils.streamToArray(); + const readWriteStream: ReadableWritablePair = { + readable: stream, + writable: outputStream, + }; + type TestType = { + metadata: { + token: string; + }; + data: JSONValue; + }; + const failureMessage: JsonRpcResponseError = { + jsonrpc: '2.0', + id: null, + error: { + code: 1, + message: 'failure of somekind', + }, + }; rpcServer.handleStream(readWriteStream, {} as ConnectionInfo); expect((await outputResult).toString()).toEqual( JSON.stringify(failureMessage), diff --git a/tests/clientRPC/handlers/agentUnlock.test.ts b/tests/clientRPC/handlers/agentUnlock.test.ts index aeca08758..4397df001 100644 --- a/tests/clientRPC/handlers/agentUnlock.test.ts +++ b/tests/clientRPC/handlers/agentUnlock.test.ts @@ -14,8 +14,8 @@ import CertManager from '@/keys/CertManager'; import { agentUnlock } from '@/clientRPC/handlers/agentUnlock'; import RPCClient from '@/RPC/RPCClient'; import { Session, SessionManager } from '@/sessions'; -import * as abcUtils from '@/clientRPC/utils'; import * as clientRPCUtils from '@/clientRPC/utils'; +import * as rpcUtils from '@/RPC/utils'; import * as testsUtils from '../../utils'; describe('agentUnlock', () => { @@ -95,14 +95,14 @@ describe('agentUnlock', () => { }; const rpcServer = await RPCServer.createRPCServer({ manifest, + middleware: rpcUtils.defaultMiddlewareWrapper( + clientRPCUtils.authenticationMiddlewareServer(sessionManager, keyRing), + ), container: { logger, }, logger, }); - rpcServer.registerMiddleware( - abcUtils.authenticationMiddlewareServer(sessionManager, keyRing), - ); wss = clientRPCUtils.createClientServer( server, rpcServer, @@ -120,13 +120,13 @@ describe('agentUnlock', () => { logger, }); rpcClient.registerMiddleware( - abcUtils.authenticationMiddlewareClient(session), + clientRPCUtils.authenticationMiddlewareClient(session), ); // Doing the test const result = await rpcClient.methods.agentUnlock({ metadata: { - Authorization: abcUtils.encodeAuthFromPassword(password), + Authorization: clientRPCUtils.encodeAuthFromPassword(password), }, data: null, }); diff --git a/tests/clientRPC/websocket.test.ts b/tests/clientRPC/websocket.test.ts index e66526bfc..9b1c64380 100644 --- a/tests/clientRPC/websocket.test.ts +++ b/tests/clientRPC/websocket.test.ts @@ -70,7 +70,7 @@ describe('websocket', () => { test1, test2, }; - const rpcServer = new RPCServer({ + const rpcServer = await RPCServer.createRPCServer({ manifest, container: {}, logger: logger.getChild('RPCServer'), From 6cceedaf06ed667c690547a59d8b6c301379537c Mon Sep 17 00:00:00 2001 From: Brian Botha Date: Wed, 8 Feb 2023 11:30:40 +1100 Subject: [PATCH 32/44] fix: general cleaning up [ci skip] --- src/RPC/RPCClient.ts | 29 ++-- src/RPC/RPCServer.ts | 130 +++++++-------- src/RPC/errors.ts | 29 +--- src/RPC/types.ts | 120 +++++++------- src/RPC/utils.ts | 253 +++++++++++------------------- tests/RPC/RPC.test.ts | 7 +- tests/RPC/RPCClient.test.ts | 50 +++--- tests/RPC/RPCServer.test.ts | 160 +++++++++---------- tests/RPC/utils.test.ts | 33 ++-- tests/RPC/utils.ts | 151 +++++++----------- tests/clientRPC/websocket.test.ts | 12 +- 11 files changed, 428 insertions(+), 546 deletions(-) diff --git a/src/RPC/RPCClient.ts b/src/RPC/RPCClient.ts index b515baa9d..b8f6b4276 100644 --- a/src/RPC/RPCClient.ts +++ b/src/RPC/RPCClient.ts @@ -21,7 +21,6 @@ import { CreateDestroy, ready } from '@matrixai/async-init/dist/CreateDestroy'; import Logger from '@matrixai/logger'; import * as rpcErrors from './errors'; import * as rpcUtils from './utils'; -import { getHandlerTypes } from './utils'; // eslint-disable-next-line interface RPCClient extends CreateDestroy {} @@ -50,7 +49,7 @@ class RPCClient { protected streamPairCreateCallback: StreamPairCreateCallback; protected callerTypes: Record; // Method proxies - protected methodsProxy = new Proxy( + public readonly methodsProxy = new Proxy( {}, { get: (_, method) => { @@ -103,7 +102,7 @@ class RPCClient { streamPairCreateCallback: StreamPairCreateCallback; logger: Logger; }) { - this.callerTypes = getHandlerTypes(manifest); + this.callerTypes = rpcUtils.getHandlerTypes(manifest); this.streamPairCreateCallback = streamPairCreateCallback; this.logger = logger; } @@ -147,9 +146,9 @@ class RPCClient { ): Promise> { // Creating caller side transforms const outputMessageTransforStream = - new rpcUtils.ClientOutputTransformerStream(); + rpcUtils.clientOutputTransformStream(); const inputMessageTransformStream = - new rpcUtils.ClientInputTransformerStream(method); + rpcUtils.clientInputTransformStream(method); let reverseStream = outputMessageTransforStream.writable; let forwardStream = inputMessageTransformStream.readable; // Setting up middleware chains @@ -163,12 +162,12 @@ class RPCClient { const streamPair = await this.streamPairCreateCallback(); void streamPair.readable .pipeThrough( - new rpcUtils.JsonToJsonMessageStream(rpcUtils.parseJsonRpcResponse), + rpcUtils.binaryToJsonMessageStream(rpcUtils.parseJsonRpcResponse), ) .pipeTo(reverseStream) .catch(() => {}); void forwardStream - .pipeThrough(new rpcUtils.JsonMessageToJsonStream()) + .pipeThrough(rpcUtils.jsonMessageToBinaryStream()) .pipeTo(streamPair.writable) .catch(() => {}); @@ -300,20 +299,20 @@ class RPCClient { protected middleware: Array< MiddlewareFactory< - JsonRpcRequest, - JsonRpcRequest, - JsonRpcResponse, - JsonRpcResponse + JsonRpcRequest, + JsonRpcRequest, + JsonRpcResponse, + JsonRpcResponse > > = []; @ready(new rpcErrors.ErrorRpcDestroyed()) public registerMiddleware( middlewareFactory: MiddlewareFactory< - JsonRpcRequest, - JsonRpcRequest, - JsonRpcResponse, - JsonRpcResponse + JsonRpcRequest, + JsonRpcRequest, + JsonRpcResponse, + JsonRpcResponse >, ) { this.middleware.push(middlewareFactory); diff --git a/src/RPC/RPCServer.ts b/src/RPC/RPCServer.ts index 81a8ea4b7..d37ba0d66 100644 --- a/src/RPC/RPCServer.ts +++ b/src/RPC/RPCServer.ts @@ -37,10 +37,10 @@ class RPCServer { manifest: Manifest; container: POJO; middleware?: MiddlewareFactory< - JsonRpcRequest, + JsonRpcRequest, Uint8Array, Uint8Array, - JsonRpcResponseResult + JsonRpcResponseResult >; logger?: Logger; }): Promise { @@ -62,10 +62,10 @@ class RPCServer { protected activeStreams: Set> = new Set(); protected events: EventTarget = new EventTarget(); protected middleware: MiddlewareFactory< - JsonRpcRequest, + JsonRpcRequest, Uint8Array, Uint8Array, - JsonRpcResponseResult + JsonRpcResponseResult >; public constructor({ @@ -77,10 +77,10 @@ class RPCServer { manifest: Manifest; container: POJO; middleware: MiddlewareFactory< - JsonRpcRequest, + JsonRpcRequest, Uint8Array, Uint8Array, - JsonRpcResponseResult + JsonRpcResponseResult >; logger: Logger; }) { @@ -147,9 +147,7 @@ class RPCServer { const forwardStream = input.pipeThrough(middleware.forward); const reverseStream = middleware.reverse.writable; const events = this.events; - const outputGen = async function* (): AsyncGenerator< - JsonRpcResponse - > { + const outputGen = async function* (): AsyncGenerator { if (ctx.signal.aborted) throw ctx.signal.reason; const dataGen = async function* () { for await (const data of forwardStream) { @@ -162,7 +160,7 @@ class RPCServer { connectionInfo, ctx, )) { - const responseMessage: JsonRpcResponseResult = { + const responseMessage: JsonRpcResponseResult = { jsonrpc: '2.0', result: response, id: null, @@ -171,9 +169,7 @@ class RPCServer { } }; const outputGenerator = outputGen(); - const reverseMiddlewareStream = new ReadableStream< - JsonRpcResponse - >({ + const reverseMiddlewareStream = new ReadableStream({ pull: async (controller) => { try { const { value, done } = await outputGenerator.next(); @@ -206,6 +202,9 @@ class RPCServer { }), ); } + await forwardStream.cancel( + new rpcErrors.ErrorRpcHandlerFailed('Error clean up'), + ); controller.close(); } }, @@ -280,60 +279,63 @@ class RPCServer { // This will take a buffer stream of json messages and set up service // handling for it. // Constructing the PromiseCancellable for tracking the active stream - const handlerProm: PromiseCancellable = new PromiseCancellable( - (resolve, reject, signal) => { - const prom = (async () => { - const { firstMessageProm, headTransformStream } = - rpcUtils.extractFirstMessageTransform(rpcUtils.parseJsonRpcRequest); - const inputStreamEndProm = streamPair.readable.pipeTo( - headTransformStream.writable, - ); - const inputStream = headTransformStream.readable; - // Read a single empty value to consume the first message - const reader = inputStream.getReader(); - await reader.read(); - reader.releaseLock(); - const leadingMetadataMessage = await firstMessageProm; - // If the stream ends early then we just stop processing - if (leadingMetadataMessage == null) { - await inputStream.cancel(); - await streamPair.writable.close(); - await inputStreamEndProm; - return; - } - const method = leadingMetadataMessage.method; - const handler = this.handlerMap.get(method); - if (handler == null) { - await inputStream.cancel(); - await streamPair.writable.close(); - await inputStreamEndProm; - return; - } - if (signal.aborted) { - await inputStream.cancel(); - await streamPair.writable.close(); - await inputStreamEndProm; - return; - } - const outputStream = handler( - [inputStream, leadingMetadataMessage], - this.container, - connectionInfo, - { signal }, - ); - await Promise.allSettled([ - inputStreamEndProm, - outputStream.pipeTo(streamPair.writable), - ]); - })(); - prom.then(resolve, reject); - }, + const abortController = new AbortController(); + const prom = (async () => { + const { firstMessageProm, headTransformStream } = + rpcUtils.extractFirstMessageTransform(rpcUtils.parseJsonRpcRequest); + const inputStreamEndProm = streamPair.readable + .pipeTo(headTransformStream.writable) + .catch(() => {}); + const inputStream = headTransformStream.readable; + // Read a single empty value to consume the first message + const reader = inputStream.getReader(); + await reader.read(); + reader.releaseLock(); + const leadingMetadataMessage = await firstMessageProm; + // If the stream ends early then we just stop processing + if (leadingMetadataMessage == null) { + await inputStream.cancel( + new rpcErrors.ErrorRpcHandlerFailed('Missing header'), + ); + await streamPair.writable.close(); + await inputStreamEndProm; + return; + } + const method = leadingMetadataMessage.method; + const handler = this.handlerMap.get(method); + if (handler == null) { + await inputStream.cancel( + new rpcErrors.ErrorRpcHandlerFailed('Missing handler'), + ); + await streamPair.writable.close(); + await inputStreamEndProm; + return; + } + if (abortController.signal.aborted) { + await inputStream.cancel( + new rpcErrors.ErrorRpcHandlerFailed('Aborted'), + ); + await streamPair.writable.close(); + await inputStreamEndProm; + return; + } + const outputStream = handler( + [inputStream, leadingMetadataMessage], + this.container, + connectionInfo, + { signal: abortController.signal }, + ); + await Promise.allSettled([ + inputStreamEndProm, + outputStream.pipeTo(streamPair.writable), + ]); + })(); + const handlerProm = PromiseCancellable.from(prom).finally( + () => this.activeStreams.delete(handlerProm), + abortController, ); // Putting the PromiseCancellable into the active streams map this.activeStreams.add(handlerProm); - void handlerProm - .finally(() => this.activeStreams.delete(handlerProm)) - .catch(() => {}); } @ready(new rpcErrors.ErrorRpcDestroyed()) diff --git a/src/RPC/errors.ts b/src/RPC/errors.ts index d0e2b8e7e..c434efdb8 100644 --- a/src/RPC/errors.ts +++ b/src/RPC/errors.ts @@ -2,21 +2,11 @@ import { ErrorPolykey, sysexits } from '../errors'; class ErrorRpc extends ErrorPolykey {} -class ErrorRpcRunning extends ErrorRpc { - static description = 'Rpc is running'; - exitCode = sysexits.USAGE; -} - class ErrorRpcDestroyed extends ErrorRpc { static description = 'Rpc is destroyed'; exitCode = sysexits.USAGE; } -class ErrorRpcNotRunning extends ErrorRpc { - static description = 'Rpc is not running'; - exitCode = sysexits.USAGE; -} - class ErrorRpcStopping extends ErrorRpc { static description = 'Rpc is stopping'; exitCode = sysexits.USAGE; @@ -27,14 +17,12 @@ class ErrorRpcParse extends ErrorRpc { exitCode = sysexits.SOFTWARE; } -class ErrorRpcHandlerMissing extends ErrorRpc { - static description = 'No handler was registered for the given method'; - exitCode = sysexits.USAGE; -} - -class ErrorRpcProtocal extends ErrorRpc { - static description = 'Unexpected behaviour during communication'; - exitCode = sysexits.PROTOCOL; +/** + * This is an internal error, it should not reach the top level. + */ +class ErrorRpcHandlerFailed extends ErrorRpc { + static description = 'Failed to handle stream'; + exitCode = sysexits.SOFTWARE; } class ErrorRpcMessageLength extends ErrorRpc { @@ -58,13 +46,10 @@ class ErrorRpcPlaceholderConnectionError extends ErrorRpcNoMessageError { export { ErrorRpc, - ErrorRpcRunning, ErrorRpcDestroyed, - ErrorRpcNotRunning, ErrorRpcStopping, ErrorRpcParse, - ErrorRpcHandlerMissing, - ErrorRpcProtocal, + ErrorRpcHandlerFailed, ErrorRpcMessageLength, ErrorRpcRemoteError, ErrorRpcNoMessageError, diff --git a/src/RPC/types.ts b/src/RPC/types.ts index 312e65973..8d1d6e27b 100644 --- a/src/RPC/types.ts +++ b/src/RPC/types.ts @@ -10,7 +10,7 @@ import type { /** * This is the JSON RPC request object. this is the generic message type used for the RPC. */ -type JsonRpcRequestMessage = { +type JsonRpcRequestMessage = { // A String specifying the version of the JSON-RPC protocol. MUST be exactly "2.0" jsonrpc: '2.0'; // A String containing the name of the method to be invoked. Method names that begin with the word rpc followed by a @@ -26,7 +26,7 @@ type JsonRpcRequestMessage = { id: string | number | null; }; -type JsonRpcRequestNotification = { +type JsonRpcRequestNotification = { // A String specifying the version of the JSON-RPC protocol. MUST be exactly "2.0" jsonrpc: '2.0'; // A String containing the name of the method to be invoked. Method names that begin with the word rpc followed by a @@ -38,7 +38,7 @@ type JsonRpcRequestNotification = { params?: T; }; -type JsonRpcResponseResult = { +type JsonRpcResponseResult = { // A String specifying the version of the JSON-RPC protocol. MUST be exactly "2.0". jsonrpc: '2.0'; // This member is REQUIRED on success. @@ -91,15 +91,15 @@ type JsonRpcError = { data?: JSONValue; }; -type JsonRpcRequest = +type JsonRpcRequest = | JsonRpcRequestMessage | JsonRpcRequestNotification; -type JsonRpcResponse = +type JsonRpcResponse = | JsonRpcResponseResult | JsonRpcResponseError; -type JsonRpcMessage = +type JsonRpcMessage = | JsonRpcRequest | JsonRpcResponse; @@ -111,57 +111,57 @@ type Handler = ( ctx: ContextCancellable, ) => O; type RawDuplexStreamHandler = Handler< - [ReadableStream, JsonRpcRequest], + [ReadableStream, JsonRpcRequest], ReadableStream >; -type DuplexStreamHandler = Handler< - AsyncGenerator, - AsyncGenerator ->; -type ServerStreamHandler = Handler< - I, - AsyncGenerator ->; -type ClientStreamHandler = Handler< - AsyncGenerator, - Promise ->; -type UnaryHandler = Handler< - I, - Promise ->; +type DuplexStreamHandler< + I extends JSONValue = JSONValue, + O extends JSONValue = JSONValue, +> = Handler, AsyncGenerator>; +type ServerStreamHandler< + I extends JSONValue = JSONValue, + O extends JSONValue = JSONValue, +> = Handler>; +type ClientStreamHandler< + I extends JSONValue = JSONValue, + O extends JSONValue = JSONValue, +> = Handler, Promise>; +type UnaryHandler< + I extends JSONValue = JSONValue, + O extends JSONValue = JSONValue, +> = Handler>; type StreamPairCreateCallback = () => Promise< ReadableWritablePair >; -type MiddlewareFactory = ( - header?: JsonRpcRequest, -) => { +type MiddlewareFactory = (header?: JsonRpcRequest) => { forward: ReadableWritablePair; reverse: ReadableWritablePair; }; type DuplexStreamCaller< - I extends JSONValue, - O extends JSONValue, + I extends JSONValue = JSONValue, + O extends JSONValue = JSONValue, > = () => Promise>; -type ServerStreamCaller = ( - parameters: I, -) => Promise>; +type ServerStreamCaller< + I extends JSONValue = JSONValue, + O extends JSONValue = JSONValue, +> = (parameters: I) => Promise>; type ClientStreamCaller< - I extends JSONValue, - O extends JSONValue, + I extends JSONValue = JSONValue, + O extends JSONValue = JSONValue, > = () => Promise<{ output: Promise; writable: WritableStream; }>; -type UnaryCaller = ( - parameters: I, -) => Promise; +type UnaryCaller< + I extends JSONValue = JSONValue, + O extends JSONValue = JSONValue, +> = (parameters: I) => Promise; type RawStreamCaller = ( params: JSONValue, @@ -192,30 +192,35 @@ type ConvertUnaryCaller = T extends UnaryHandler ? UnaryCaller : never; -type ConvertHandler = T extends DuplexStreamHandler +type ConvertHandler = T extends DuplexStreamHandler ? ConvertDuplexStreamHandler - : T extends ServerStreamHandler + : T extends ServerStreamHandler ? ConvertServerStreamHandler - : T extends ClientStreamHandler + : T extends ClientStreamHandler ? ConvertClientStreamHandler - : T extends UnaryHandler + : T extends UnaryHandler ? ConvertUnaryCaller : T extends RawDuplexStreamHandler ? RawStreamCaller : never; -type WithDuplexStreamCaller = ( - f: (output: AsyncGenerator) => AsyncGenerator, -) => Promise; +type WithDuplexStreamCaller< + I extends JSONValue = JSONValue, + O extends JSONValue = JSONValue, +> = (f: (output: AsyncGenerator) => AsyncGenerator) => Promise; -type WithServerStreamCaller = ( +type WithServerStreamCaller< + I extends JSONValue = JSONValue, + O extends JSONValue = JSONValue, +> = ( parameters: I, f: (output: AsyncGenerator) => Promise, ) => Promise; -type WithClientStreamCaller = ( - f: () => AsyncGenerator, -) => Promise; +type WithClientStreamCaller< + I extends JSONValue = JSONValue, + O extends JSONValue = JSONValue, +> = (f: () => AsyncGenerator) => Promise; type WithRawStreamCaller = ( params: JSONValue, @@ -243,11 +248,11 @@ type ConvertWithClientStreamHandler = T extends ClientStreamHandler< ? WithClientStreamCaller : never; -type ConvertWithHandler = T extends DuplexStreamHandler +type ConvertWithHandler = T extends DuplexStreamHandler ? ConvertWithDuplexStreamHandler - : T extends ServerStreamHandler + : T extends ServerStreamHandler ? ConvertWithServerStreamHandler - : T extends ClientStreamHandler + : T extends ClientStreamHandler ? ConvertWithClientStreamHandler : T extends RawDuplexStreamHandler ? WithRawStreamCaller @@ -255,7 +260,10 @@ type ConvertWithHandler = T extends DuplexStreamHandler type HandlerType = 'DUPLEX' | 'SERVER' | 'CLIENT' | 'UNARY' | 'RAW'; -type ManifestItem = +type ManifestItem< + I extends JSONValue = JSONValue, + O extends JSONValue = JSONValue, +> = | { type: 'DUPLEX'; handler: DuplexStreamHandler; @@ -277,18 +285,14 @@ type ManifestItem = handler: RawDuplexStreamHandler; }; -type Manifest = Record>; - -type ExtractHandler = T extends ManifestItem - ? T['handler'] - : never; +type Manifest = Record; type MapHandlers = { - [P in keyof T]: ConvertHandler>; + [K in keyof T]: ConvertHandler; }; type MapWithHandlers = { - [P in keyof T]: ConvertWithHandler>; + [K in keyof T]: ConvertWithHandler; }; export type { diff --git a/src/RPC/utils.ts b/src/RPC/utils.ts index 9cb9479ce..a5d87b143 100644 --- a/src/RPC/utils.ts +++ b/src/RPC/utils.ts @@ -1,9 +1,3 @@ -import type { - Transformer, - TransformerTransformCallback, - TransformerStartCallback, - TransformerFlushCallback, -} from 'stream/web'; import type { JsonRpcError, JsonRpcMessage, @@ -16,7 +10,6 @@ import type { MiddlewareFactory, } from 'RPC/types'; import type { JSONValue } from '../types'; -import type { JsonValue } from 'fast-check'; import type { HandlerType, Manifest } from 'RPC/types'; import { TransformStream } from 'stream/web'; import { AbstractError } from '@matrixai/errors'; @@ -27,75 +20,46 @@ import * as errors from '../errors'; import { promise } from '../utils'; const jsonStreamParsers = require('@streamparser/json'); -class JsonToJsonMessage - implements Transformer -{ - protected bytesWritten: number = 0; - - constructor( - protected messageParser: (message: unknown) => T, - protected byteLimit: number, - protected firstMessage: T | undefined, - ) {} - - protected parser = new jsonStreamParsers.JSONParser({ +function binaryToJsonMessageStream( + messageParser: (message: unknown) => T, + byteLimit: number = 1024 * 1024, + firstMessage?: T, +) { + const parser = new jsonStreamParsers.JSONParser({ separator: '', paths: ['$'], }); + let bytesWritten: number = 0; - start: TransformerStartCallback = async (controller) => { - if (this.firstMessage != null) controller.enqueue(this.firstMessage); - this.parser.onValue = (value) => { - const jsonMessage = this.messageParser(value.value); - controller.enqueue(jsonMessage); - this.bytesWritten = 0; - }; - }; - - transform: TransformerTransformCallback = async (chunk) => { - try { - this.bytesWritten += chunk.byteLength; - this.parser.write(chunk); - } catch (e) { - throw new rpcErrors.ErrorRpcParse(undefined, { cause: e }); - } - if (this.bytesWritten > this.byteLimit) { - throw new rpcErrors.ErrorRpcMessageLength(); - } - }; -} - -// TODO: rename to something more descriptive? -class JsonToJsonMessageStream extends TransformStream< - Uint8Array, - T -> { - constructor( - messageParser: (message: unknown) => T, - byteLimit: number = 1024 * 1024, - firstMessage?: T, - ) { - super(new JsonToJsonMessage(messageParser, byteLimit, firstMessage)); - } -} - -class JsonMessageToJson implements Transformer { - transform: TransformerTransformCallback = async ( - chunk, - controller, - ) => { - controller.enqueue(Buffer.from(JSON.stringify(chunk))); - }; + return new TransformStream({ + start: (controller) => { + if (firstMessage != null) controller.enqueue(firstMessage); + parser.onValue = (value) => { + const jsonMessage = messageParser(value.value); + controller.enqueue(jsonMessage); + bytesWritten = 0; + }; + }, + transform: (chunk) => { + try { + bytesWritten += chunk.byteLength; + parser.write(chunk); + } catch (e) { + throw new rpcErrors.ErrorRpcParse(undefined, { cause: e }); + } + if (bytesWritten > byteLimit) { + throw new rpcErrors.ErrorRpcMessageLength(); + } + }, + }); } -// TODO: rename to something more descriptive? -class JsonMessageToJsonStream extends TransformStream< - JsonRpcMessage, - Uint8Array -> { - constructor() { - super(new JsonMessageToJson()); - } +function jsonMessageToBinaryStream() { + return new TransformStream({ + transform: (chunk, controller) => { + controller.enqueue(Buffer.from(JSON.stringify(chunk))); + }, + }); } function parseJsonRpcRequest( @@ -436,52 +400,29 @@ function toError(errorData) { }); } -class ClientInputTransformer - implements Transformer> -{ - constructor(protected method: string) {} - - transform: TransformerTransformCallback> = - async (chunk, controller) => { - const message: JsonRpcRequest = { - method: this.method, +function clientInputTransformStream(method: string) { + return new TransformStream({ + transform: (chunk, controller) => { + const message: JsonRpcRequest = { + method, jsonrpc: '2.0', id: null, params: chunk, }; controller.enqueue(message); - }; -} - -class ClientInputTransformerStream extends TransformStream< - I, - JsonRpcRequest -> { - constructor(method: string) { - super(new ClientInputTransformer(method)); - } -} - -class ClientOutputTransformer - implements Transformer, O> -{ - transform: TransformerTransformCallback, O> = async ( - chunk, - controller, - ) => { - if ('error' in chunk) { - throw toError(chunk.error.data); - } - controller.enqueue(chunk.result); - }; + }, + }); } -class ClientOutputTransformerStream< - O extends JSONValue, -> extends TransformStream, O> { - constructor() { - super(new ClientOutputTransformer()); - } +function clientOutputTransformStream() { + return new TransformStream, O>({ + transform: (chunk, controller) => { + if ('error' in chunk) { + throw toError(chunk.error.data); + } + controller.enqueue(chunk.result); + }, + }); } function isReturnableError(e: Error): boolean { @@ -534,39 +475,31 @@ const controllerTransformationFactory = () => { }; }; -class QueueMergingTransform implements Transformer { - constructor(protected messageQueue: Array) {} - - start: TransformerStartCallback = async (controller) => { - while (true) { - const value = this.messageQueue.shift(); - if (value == null) break; - controller.enqueue(value); - } - }; - - transform: TransformerTransformCallback = async (chunk, controller) => { - while (true) { - const value = this.messageQueue.shift(); - if (value == null) break; - controller.enqueue(value); - } - controller.enqueue(chunk); - }; - - flush: TransformerFlushCallback = (controller) => { - while (true) { - const value = this.messageQueue.shift(); - if (value == null) break; - controller.enqueue(value); - } - }; -} - -class QueueMergingTransformStream extends TransformStream { - constructor(messageQueue: Array) { - super(new QueueMergingTransform(messageQueue)); - } +function queueMergingTransformStream(messageQueue: Array) { + return new TransformStream({ + start: (controller) => { + while (true) { + const value = messageQueue.shift(); + if (value == null) break; + controller.enqueue(value); + } + }, + transform: (chunk, controller) => { + while (true) { + const value = messageQueue.shift(); + if (value == null) break; + controller.enqueue(value); + } + controller.enqueue(chunk); + }, + flush: (controller) => { + while (true) { + const value = messageQueue.shift(); + if (value == null) break; + controller.enqueue(value); + } + }, + }); } function extractFirstMessageTransform( @@ -639,10 +572,10 @@ function getHandlerTypes(manifest: Manifest): Record { } const defaultMiddleware: MiddlewareFactory< - JsonRpcRequest, - JsonRpcRequest, - JsonRpcResponse, - JsonRpcResponse + JsonRpcRequest, + JsonRpcRequest, + JsonRpcResponse, + JsonRpcResponse > = () => { return { forward: new TransformStream(), @@ -652,21 +585,21 @@ const defaultMiddleware: MiddlewareFactory< const defaultMiddlewareWrapper = ( middleware: MiddlewareFactory< - JsonRpcRequest, - JsonRpcRequest, - JsonRpcResponse, - JsonRpcResponse + JsonRpcRequest, + JsonRpcRequest, + JsonRpcResponse, + JsonRpcResponse > = defaultMiddleware, ) => { - return (header: JsonRpcRequest) => { - const inputTransformStream = new JsonToJsonMessageStream( + return (header: JsonRpcRequest) => { + const inputTransformStream = binaryToJsonMessageStream( parseJsonRpcRequest, undefined, header, ); const outputTransformStream = new TransformStream< - JsonRpcResponseResult, - JsonRpcResponseResult + JsonRpcResponseResult, + JsonRpcResponseResult >(); const middleMiddleware = middleware(header); @@ -676,7 +609,7 @@ const defaultMiddlewareWrapper = ( ); // Usual middleware here const reverseReadable = outputTransformStream.readable .pipeThrough(middleMiddleware.reverse) // Usual middleware here - .pipeThrough(new JsonMessageToJsonStream()); + .pipeThrough(jsonMessageToBinaryStream()); return { forward: { @@ -692,8 +625,8 @@ const defaultMiddlewareWrapper = ( }; export { - JsonToJsonMessageStream, - JsonMessageToJsonStream, + binaryToJsonMessageStream, + jsonMessageToBinaryStream, parseJsonRpcRequest, parseJsonRpcRequestMessage, parseJsonRpcRequestNotification, @@ -703,12 +636,12 @@ export { parseJsonRpcMessage, fromError, toError, - ClientInputTransformerStream, - ClientOutputTransformerStream, + clientInputTransformStream, + clientOutputTransformStream, isReturnableError, RPCErrorEvent, controllerTransformationFactory, - QueueMergingTransformStream, + queueMergingTransformStream, extractFirstMessageTransform, getHandlerTypes, defaultMiddleware, diff --git a/tests/RPC/RPC.test.ts b/tests/RPC/RPC.test.ts index f7efb9369..37516a630 100644 --- a/tests/RPC/RPC.test.ts +++ b/tests/RPC/RPC.test.ts @@ -1,6 +1,5 @@ import type { JsonRpcRequest, ManifestItem } from '@/RPC/types'; import type { ConnectionInfo } from '@/network/types'; -import type { JSONValue } from '@/types'; import { fc, testProp } from '@fast-check/jest'; import Logger, { LogLevel, StreamHandler } from '@matrixai/logger'; import RPCServer from '@/RPC/RPCServer'; @@ -22,7 +21,7 @@ describe('RPC', () => { >(); let header: JsonRpcRequest | undefined; - const testMethod: ManifestItem = { + const testMethod: ManifestItem = { type: 'RAW', handler: ([input, header_], _container, _connectionInfo, _ctx) => { header = header_; @@ -77,7 +76,7 @@ describe('RPC', () => { Uint8Array >(); - const testMethod: ManifestItem = { + const testMethod: ManifestItem = { type: 'DUPLEX', handler: async function* (input, _container, _connectionInfo, _ctx) { for await (const val of input) { @@ -220,7 +219,7 @@ describe('RPC', () => { Uint8Array >(); - const testMethod: ManifestItem = { + const testMethod: ManifestItem = { type: 'UNARY', handler: async (input) => input, }; diff --git a/tests/RPC/RPCClient.test.ts b/tests/RPC/RPCClient.test.ts index 863bb99fa..c5fd9212b 100644 --- a/tests/RPC/RPCClient.test.ts +++ b/tests/RPC/RPCClient.test.ts @@ -79,7 +79,7 @@ describe(`${RPCClient.name}`, () => { }, ); testProp('generic duplex caller', [specificMessageArb], async (messages) => { - const inputStream = rpcTestUtils.jsonRpcStream(messages); + const inputStream = rpcTestUtils.messagesToReadableStream(messages); const [outputResult, outputStream] = rpcTestUtils.streamToArray(); const streamPair: ReadableWritablePair = { @@ -125,7 +125,7 @@ describe(`${RPCClient.name}`, () => { 'generic server stream caller', [specificMessageArb, rpcTestUtils.safeJsonValueArb], async (messages, params) => { - const inputStream = rpcTestUtils.jsonRpcStream(messages); + const inputStream = rpcTestUtils.messagesToReadableStream(messages); const [outputResult, outputStream] = rpcTestUtils.streamToArray(); const streamPair: ReadableWritablePair = { readable: inputStream, @@ -164,7 +164,7 @@ describe(`${RPCClient.name}`, () => { fc.array(rpcTestUtils.safeJsonValueArb), ], async (message, params) => { - const inputStream = rpcTestUtils.jsonRpcStream([message]); + const inputStream = rpcTestUtils.messagesToReadableStream([message]); const [outputResult, outputStream] = rpcTestUtils.streamToArray(); const streamPair: ReadableWritablePair = { @@ -204,7 +204,7 @@ describe(`${RPCClient.name}`, () => { 'generic unary caller', [rpcTestUtils.jsonRpcResponseResultArb(), rpcTestUtils.safeJsonValueArb], async (message, params) => { - const inputStream = rpcTestUtils.jsonRpcStream([message]); + const inputStream = rpcTestUtils.messagesToReadableStream([message]); const [outputResult, outputStream] = rpcTestUtils.streamToArray(); const streamPair: ReadableWritablePair = { readable: inputStream, @@ -238,7 +238,7 @@ describe(`${RPCClient.name}`, () => { rpcTestUtils.jsonRpcResponseErrorArb(), ], async (messages, errorMessage) => { - const inputStream = rpcTestUtils.jsonRpcStream([ + const inputStream = rpcTestUtils.messagesToReadableStream([ ...messages, errorMessage, ]); @@ -326,7 +326,7 @@ describe(`${RPCClient.name}`, () => { 'withDuplexCaller', [fc.array(rpcTestUtils.jsonRpcResponseResultArb(), { minLength: 1 })], async (messages) => { - const inputStream = rpcTestUtils.jsonRpcStream(messages); + const inputStream = rpcTestUtils.messagesToReadableStream(messages); const [outputResult, outputStream] = rpcTestUtils.streamToArray(); const streamPair: ReadableWritablePair = { @@ -346,7 +346,7 @@ describe(`${RPCClient.name}`, () => { } }); const result = await outputResult; - // We're just checking that it consuming the messages as expected + // We're just checking that it's consuming the messages as expected expect(result.length).toEqual(messages.length); expect(count).toEqual(messages.length); await rpcClient.destroy(); @@ -359,7 +359,7 @@ describe(`${RPCClient.name}`, () => { rpcTestUtils.safeJsonValueArb, ], async (messages, params) => { - const inputStream = rpcTestUtils.jsonRpcStream(messages); + const inputStream = rpcTestUtils.messagesToReadableStream(messages); const [outputResult, outputStream] = rpcTestUtils.streamToArray(); const streamPair: ReadableWritablePair = { @@ -395,7 +395,7 @@ describe(`${RPCClient.name}`, () => { fc.array(rpcTestUtils.safeJsonValueArb, { minLength: 2 }).noShrink(), ], async (message, inputMessages) => { - const inputStream = rpcTestUtils.jsonRpcStream([message]); + const inputStream = rpcTestUtils.messagesToReadableStream([message]); const [outputResult, outputStream] = rpcTestUtils.streamToArray(); const streamPair: ReadableWritablePair = { @@ -434,7 +434,7 @@ describe(`${RPCClient.name}`, () => { 'generic duplex caller with forward Middleware', [specificMessageArb], async (messages) => { - const inputStream = rpcTestUtils.jsonRpcStream(messages); + const inputStream = rpcTestUtils.messagesToReadableStream(messages); const [outputResult, outputStream] = rpcTestUtils.streamToArray(); const streamPair: ReadableWritablePair = { @@ -449,10 +449,7 @@ describe(`${RPCClient.name}`, () => { rpcClient.registerMiddleware(() => { return { - forward: new TransformStream< - JsonRpcRequest, - JsonRpcRequest - >({ + forward: new TransformStream({ transform: (chunk, controller) => { controller.enqueue({ ...chunk, @@ -501,7 +498,7 @@ describe(`${RPCClient.name}`, () => { 'generic duplex caller with reverse Middleware', [specificMessageArb], async (messages) => { - const inputStream = rpcTestUtils.jsonRpcStream(messages); + const inputStream = rpcTestUtils.messagesToReadableStream(messages); const [outputResult, outputStream] = rpcTestUtils.streamToArray(); const streamPair: ReadableWritablePair = { @@ -517,10 +514,7 @@ describe(`${RPCClient.name}`, () => { rpcClient.registerMiddleware(() => { return { forward: new TransformStream(), - reverse: new TransformStream< - JsonRpcResponse, - JsonRpcResponse - >({ + reverse: new TransformStream({ transform: (chunk, controller) => { controller.enqueue({ ...chunk, @@ -558,7 +552,7 @@ describe(`${RPCClient.name}`, () => { }), ], async (messages) => { - const inputStream = rpcTestUtils.jsonRpcStream(messages); + const inputStream = rpcTestUtils.messagesToReadableStream(messages); const [outputResult, outputStream] = rpcTestUtils.streamToArray(); const streamPair: ReadableWritablePair = { readable: inputStream, @@ -612,7 +606,7 @@ describe(`${RPCClient.name}`, () => { 'manifest server call', [specificMessageArb, fc.string()], async (messages, params) => { - const inputStream = rpcTestUtils.jsonRpcStream(messages); + const inputStream = rpcTestUtils.messagesToReadableStream(messages); const [outputResult, outputStream] = rpcTestUtils.streamToArray(); const streamPair: ReadableWritablePair = { readable: inputStream, @@ -656,7 +650,7 @@ describe(`${RPCClient.name}`, () => { fc.array(fc.string(), { minLength: 5 }), ], async (message, params) => { - const inputStream = rpcTestUtils.jsonRpcStream([message]); + const inputStream = rpcTestUtils.messagesToReadableStream([message]); const [outputResult, outputStream] = rpcTestUtils.streamToArray(); const streamPair: ReadableWritablePair = { @@ -701,7 +695,7 @@ describe(`${RPCClient.name}`, () => { 'manifest unary call', [rpcTestUtils.jsonRpcResponseResultArb().noShrink(), fc.string()], async (message, params) => { - const inputStream = rpcTestUtils.jsonRpcStream([message]); + const inputStream = rpcTestUtils.messagesToReadableStream([message]); const [outputResult, outputStream] = rpcTestUtils.streamToArray(); const streamPair: ReadableWritablePair = { readable: inputStream, @@ -754,7 +748,7 @@ describe(`${RPCClient.name}`, () => { }), writable: inputWritableStream, }; - const raw: ManifestItem = { + const raw: ManifestItem = { type: 'RAW', handler: ([input]) => input, }; @@ -794,7 +788,7 @@ describe(`${RPCClient.name}`, () => { }), ], async (messages) => { - const inputStream = rpcTestUtils.jsonRpcStream(messages); + const inputStream = rpcTestUtils.messagesToReadableStream(messages); const [outputResult, outputStream] = rpcTestUtils.streamToArray(); const streamPair: ReadableWritablePair = { @@ -835,7 +829,7 @@ describe(`${RPCClient.name}`, () => { fc.string(), ], async (messages, params) => { - const inputStream = rpcTestUtils.jsonRpcStream(messages); + const inputStream = rpcTestUtils.messagesToReadableStream(messages); const [outputResult, outputStream] = rpcTestUtils.streamToArray(); const streamPair: ReadableWritablePair = { @@ -879,7 +873,7 @@ describe(`${RPCClient.name}`, () => { fc.array(fc.string(), { minLength: 2 }).noShrink(), ], async (message, inputMessages) => { - const inputStream = rpcTestUtils.jsonRpcStream([message]); + const inputStream = rpcTestUtils.messagesToReadableStream([message]); const [outputResult, outputStream] = rpcTestUtils.streamToArray(); const streamPair: ReadableWritablePair = { @@ -940,7 +934,7 @@ describe(`${RPCClient.name}`, () => { }), writable: inputWritableStream, }; - const raw: ManifestItem = { + const raw: ManifestItem = { type: 'RAW', handler: ([input]) => input, }; diff --git a/tests/RPC/RPCServer.test.ts b/tests/RPC/RPCServer.test.ts index c4f35b6dc..d21215026 100644 --- a/tests/RPC/RPCServer.test.ts +++ b/tests/RPC/RPCServer.test.ts @@ -1,5 +1,4 @@ import type { - JsonRpcMessage, JsonRpcRequest, JsonRpcResponse, JsonRpcResponseError, @@ -39,8 +38,6 @@ describe(`${RPCServer.name}`, () => { ); const errorArb = fc.oneof( fc.constant(new rpcErrors.ErrorRpcParse()), - fc.constant(new rpcErrors.ErrorRpcHandlerMissing()), - fc.constant(new rpcErrors.ErrorRpcProtocal()), fc.constant(new rpcErrors.ErrorRpcMessageLength()), fc.constant(new rpcErrors.ErrorRpcRemoteError()), ); @@ -60,11 +57,11 @@ describe(`${RPCServer.name}`, () => { [specificMessageArb], async (messages) => { const stream = rpcTestUtils - .jsonRpcStream(messages) + .messagesToReadableStream(messages) .pipeThrough( - new rpcTestUtils.BufferStreamToSnippedStream([4, 7, 13, 2, 6]), + rpcTestUtils.binaryStreamToSnippedStream([4, 7, 13, 2, 6]), ); - const testMethod: ManifestItem = { + const testMethod: ManifestItem = { type: 'RAW', handler: ([input]) => { void (async () => { @@ -103,8 +100,8 @@ describe(`${RPCServer.name}`, () => { 'can stream data with duplex stream handler', [specificMessageArb], async (messages) => { - const stream = rpcTestUtils.jsonRpcStream(messages); - const testMethod: ManifestItem = { + const stream = rpcTestUtils.messagesToReadableStream(messages); + const testMethod: ManifestItem = { type: 'DUPLEX', handler: async function* (input, _container, _connectionInfo, _ctx) { for await (const val of input) { @@ -135,8 +132,8 @@ describe(`${RPCServer.name}`, () => { 'can stream data with client stream handler', [specificMessageArb], async (messages) => { - const stream = rpcTestUtils.jsonRpcStream(messages); - const testMethod: ManifestItem = { + const stream = rpcTestUtils.messagesToReadableStream(messages); + const testMethod: ManifestItem = { type: 'CLIENT', handler: async function (input, _container, _connectionInfo, _ctx) { let count = 0; @@ -168,7 +165,7 @@ describe(`${RPCServer.name}`, () => { 'can stream data with server stream handler', [singleNumberMessageArb], async (messages) => { - const stream = rpcTestUtils.jsonRpcStream(messages); + const stream = rpcTestUtils.messagesToReadableStream(messages); const testMethod: ManifestItem = { type: 'SERVER', handler: async function* (input, _container, _connectionInfo, _ctx) { @@ -199,8 +196,8 @@ describe(`${RPCServer.name}`, () => { 'can stream data with server stream handler', [specificMessageArb], async (messages) => { - const stream = rpcTestUtils.jsonRpcStream(messages); - const testMethod: ManifestItem = { + const stream = rpcTestUtils.messagesToReadableStream(messages); + const testMethod: ManifestItem = { type: 'UNARY', handler: async (input, _container, _connectionInfo, _ctx) => input, }; @@ -226,8 +223,8 @@ describe(`${RPCServer.name}`, () => { 'Handler is provided with container', [specificMessageArb], async (messages) => { - const stream = rpcTestUtils.jsonRpcStream(messages); - const testMethod: ManifestItem = { + const stream = rpcTestUtils.messagesToReadableStream(messages); + const testMethod: ManifestItem = { type: 'DUPLEX', handler: async function* (input, container_, _connectionInfo, _ctx) { expect(container_).toBe(container); @@ -262,7 +259,7 @@ describe(`${RPCServer.name}`, () => { 'Handler is provided with connectionInfo', [specificMessageArb], async (messages) => { - const stream = rpcTestUtils.jsonRpcStream(messages); + const stream = rpcTestUtils.messagesToReadableStream(messages); const connectionInfo: ConnectionInfo = { localHost: 'hostA' as Host, localPort: 12341 as Port, @@ -272,7 +269,7 @@ describe(`${RPCServer.name}`, () => { remotePort: 12341 as Port, }; let handledConnectionInfo; - const testMethod: ManifestItem = { + const testMethod: ManifestItem = { type: 'DUPLEX', handler: async function* (input, _container, connectionInfo_, _ctx) { handledConnectionInfo = connectionInfo_; @@ -302,60 +299,59 @@ describe(`${RPCServer.name}`, () => { ); // Problem with the tap stream. It seems to block the whole stream. // If I don't pipe the tap to the output we actually iterate over some data. - testProp.skip( - 'Handler can be aborted', - [specificMessageArb], - async (messages) => { - const stream = rpcTestUtils.jsonRpcStream(messages); - const testMethod: ManifestItem = { - type: 'DUPLEX', - handler: async function* (input, _container, _connectionInfo, ctx) { - for await (const val of input) { - if (ctx.signal.aborted) throw ctx.signal.reason; - yield val; + testProp('Handler can be aborted', [specificMessageArb], async (messages) => { + const stream = rpcTestUtils.messagesToReadableStream(messages); + const testMethod: ManifestItem = { + type: 'DUPLEX', + handler: async function* (input, _container, _connectionInfo, ctx) { + for await (const val of input) { + if (ctx.signal.aborted) throw ctx.signal.reason; + yield val; + } + }, + }; + const container = {}; + const rpcServer = await RPCServer.createRPCServer({ + manifest: { + testMethod, + }, + container, + logger, + }); + const [outputResult, outputStream] = + rpcTestUtils.streamToArray(); + let thing; + const tapStream = rpcTestUtils.tapTransformStream( + async (_, iteration) => { + if (iteration === 2) { + // @ts-ignore: kidnap private property + const activeStreams = rpcServer.activeStreams.values(); + // @ts-ignore: kidnap private property + for (const activeStream of activeStreams) { + thing = activeStream; + activeStream.cancel(new rpcErrors.ErrorRpcStopping()); } - }, - }; - const container = {}; - const rpcServer = await RPCServer.createRPCServer({ - manifest: { - testMethod, - }, - container, - logger, - }); - const [outputResult, outputStream] = rpcTestUtils.streamToArray(); - let thing; - let lastMessage: JsonRpcMessage | undefined; - const tapStream: any = {}; - // Const tapStream = new rpcTestUtils.TapStream( - // async (_, iteration) => { - // if (iteration === 2) { - // // @ts-ignore: kidnap private property - // const activeStreams = rpcServer.activeStreams.values(); - // for (const activeStream of activeStreams) { - // thing = activeStream; - // activeStream.cancel(new rpcErrors.ErrorRpcStopping()); - // } - // } - // }, - // ); - await tapStream.readable.pipeTo(outputStream); - const readWriteStream: ReadableWritablePair = { - readable: stream, - writable: tapStream.writable, - }; - rpcServer.handleStream(readWriteStream, {} as ConnectionInfo); - await outputResult; - await expect(thing).toResolve(); - // Last message should be an error message - expect(lastMessage).toBeDefined(); - await rpcServer.destroy(); - }, - ); + } + }, + ); + void tapStream.readable.pipeTo(outputStream).catch(() => {}); + const readWriteStream: ReadableWritablePair = { + readable: stream, + writable: tapStream.writable, + }; + rpcServer.handleStream(readWriteStream, {} as ConnectionInfo); + const result = await outputResult; + const lastMessage = result[result.length - 1]; + await expect(thing).toResolve(); + expect(lastMessage).toBeDefined(); + expect(() => + rpcUtils.parseJsonRpcResponseError(JSON.parse(lastMessage.toString())), + ).not.toThrow(); + await rpcServer.destroy(); + }); testProp('Handler yields nothing', [specificMessageArb], async (messages) => { - const stream = rpcTestUtils.jsonRpcStream(messages); - const testMethod: ManifestItem = { + const stream = rpcTestUtils.messagesToReadableStream(messages); + const testMethod: ManifestItem = { type: 'DUPLEX', handler: async function* (input, _container, _connectionInfo, _ctx) { for await (const _ of input) { @@ -385,8 +381,8 @@ describe(`${RPCServer.name}`, () => { 'should send error message', [specificMessageArb, errorArb], async (messages, error) => { - const stream = rpcTestUtils.jsonRpcStream(messages); - const testMethod: ManifestItem = { + const stream = rpcTestUtils.messagesToReadableStream(messages); + const testMethod: ManifestItem = { type: 'DUPLEX', handler: async function* (_input, _container, _connectionInfo, _ctx) { throw error; @@ -426,8 +422,8 @@ describe(`${RPCServer.name}`, () => { 'should emit stream error', [specificMessageArb], async (messages) => { - const stream = rpcTestUtils.jsonRpcStream(messages); - const testMethod: ManifestItem = { + const stream = rpcTestUtils.messagesToReadableStream(messages); + const testMethod: ManifestItem = { type: 'DUPLEX', handler: async function* (_input, _container, _connectionInfo, _ctx) { throw new rpcErrors.ErrorRpcPlaceholderConnectionError(); @@ -463,8 +459,8 @@ describe(`${RPCServer.name}`, () => { }, ); testProp('forward middlewares', [specificMessageArb], async (messages) => { - const stream = rpcTestUtils.jsonRpcStream(messages); - const testMethod: ManifestItem = { + const stream = rpcTestUtils.messagesToReadableStream(messages); + const testMethod: ManifestItem = { type: 'DUPLEX', handler: async function* (input, _container, _connectionInfo, _ctx) { for await (const val of input) { @@ -511,8 +507,8 @@ describe(`${RPCServer.name}`, () => { await rpcServer.destroy(); }); testProp('reverse middlewares', [specificMessageArb], async (messages) => { - const stream = rpcTestUtils.jsonRpcStream(messages); - const testMethod: ManifestItem = { + const stream = rpcTestUtils.messagesToReadableStream(messages); + const testMethod: ManifestItem = { type: 'DUPLEX', handler: async function* (input, _container, _connectionInfo, _ctx) { for await (const val of input) { @@ -562,8 +558,8 @@ describe(`${RPCServer.name}`, () => { 'forward middleware authentication', [invalidTokenMessageArb], async (message) => { - const stream = rpcTestUtils.jsonRpcStream([message]); - const testMethod: ManifestItem = { + const stream = rpcTestUtils.messagesToReadableStream([message]); + const testMethod: ManifestItem = { type: 'DUPLEX', handler: async function* (input, _container, _connectionInfo, _ctx) { for await (const val of input) { @@ -574,9 +570,7 @@ describe(`${RPCServer.name}`, () => { const container = {}; const middleware = rpcUtils.defaultMiddlewareWrapper(() => { let first = true; - let reverseController: TransformStreamDefaultController< - JsonRpcResponse - >; + let reverseController: TransformStreamDefaultController; return { forward: new TransformStream< JsonRpcRequest, @@ -628,7 +622,7 @@ describe(`${RPCServer.name}`, () => { id: null, error: { code: 1, - message: 'failure of somekind', + message: 'failure of some kind', }, }; rpcServer.handleStream(readWriteStream, {} as ConnectionInfo); diff --git a/tests/RPC/utils.test.ts b/tests/RPC/utils.test.ts index 63a1bdfec..b0414dcee 100644 --- a/tests/RPC/utils.test.ts +++ b/tests/RPC/utils.test.ts @@ -11,9 +11,9 @@ describe('utils tests', () => { [rpcTestUtils.jsonMessagesArb], async (messages) => { const parsedStream = rpcTestUtils - .jsonRpcStream(messages) + .messagesToReadableStream(messages) .pipeThrough( - new rpcUtils.JsonToJsonMessageStream(rpcUtils.parseJsonRpcMessage), + rpcUtils.binaryToJsonMessageStream(rpcUtils.parseJsonRpcMessage), ); // Converting back. const asd = await AsyncIterable.as(parsedStream).toArray(); @@ -27,10 +27,10 @@ describe('utils tests', () => { [rpcTestUtils.jsonMessagesArb, rpcTestUtils.snippingPatternArb], async (messages, snippattern) => { const parsedStream = rpcTestUtils - .jsonRpcStream(messages) - .pipeThrough(new rpcTestUtils.BufferStreamToSnippedStream(snippattern)) // Imaginary internet here + .messagesToReadableStream(messages) + .pipeThrough(rpcTestUtils.binaryStreamToSnippedStream(snippattern)) // Imaginary internet here .pipeThrough( - new rpcUtils.JsonToJsonMessageStream(rpcUtils.parseJsonRpcMessage), + rpcUtils.binaryToJsonMessageStream(rpcUtils.parseJsonRpcMessage), ); // Converting back. const asd = await AsyncIterable.as(parsedStream).toArray(); @@ -51,11 +51,11 @@ describe('utils tests', () => { [rpcTestUtils.jsonMessagesArb, rpcTestUtils.snippingPatternArb, noiseArb], async (messages, snippattern, noise) => { const parsedStream = rpcTestUtils - .jsonRpcStream(messages) - .pipeThrough(new rpcTestUtils.BufferStreamToSnippedStream(snippattern)) // Imaginary internet here - .pipeThrough(new rpcTestUtils.BufferStreamToNoisyStream(noise)) // Adding bad data to the stream + .messagesToReadableStream(messages) + .pipeThrough(rpcTestUtils.binaryStreamToSnippedStream(snippattern)) // Imaginary internet here + .pipeThrough(rpcTestUtils.binaryStreamToNoisyStream(noise)) // Adding bad data to the stream .pipeThrough( - new rpcUtils.JsonToJsonMessageStream(rpcUtils.parseJsonRpcMessage), + rpcUtils.binaryToJsonMessageStream(rpcUtils.parseJsonRpcMessage), ); // Converting back. await expect(AsyncIterable.as(parsedStream).toArray()).rejects.toThrow( @@ -86,13 +86,10 @@ describe('utils tests', () => { ], async (messages) => { const parsedStream = rpcTestUtils - .jsonRpcStream(messages) - .pipeThrough(new rpcTestUtils.BufferStreamToSnippedStream([10])) + .messagesToReadableStream(messages) + .pipeThrough(rpcTestUtils.binaryStreamToSnippedStream([10])) .pipeThrough( - new rpcUtils.JsonToJsonMessageStream( - rpcUtils.parseJsonRpcMessage, - 50, - ), + rpcUtils.binaryToJsonMessageStream(rpcUtils.parseJsonRpcMessage, 50), ); const doThing = async () => { @@ -143,11 +140,11 @@ describe('utils tests', () => { const { firstMessageProm, headTransformStream } = rpcUtils.extractFirstMessageTransform(rpcUtils.parseJsonRpcRequest); const parsedStream = rpcTestUtils - .jsonRpcStream(messages) - .pipeThrough(new rpcTestUtils.BufferStreamToSnippedStream([7])) + .messagesToReadableStream(messages) + .pipeThrough(rpcTestUtils.binaryStreamToSnippedStream([7])) .pipeThrough(headTransformStream) .pipeThrough( - new rpcUtils.JsonToJsonMessageStream(rpcUtils.parseJsonRpcMessage), + rpcUtils.binaryToJsonMessageStream(rpcUtils.parseJsonRpcMessage), ); // Converting back. expect(await firstMessageProm).toStrictEqual(messages[0]); diff --git a/tests/RPC/utils.ts b/tests/RPC/utils.ts index f8c4aed2f..8cffba4a1 100644 --- a/tests/RPC/utils.ts +++ b/tests/RPC/utils.ts @@ -1,10 +1,5 @@ -import type { - Transformer, - TransformerFlushCallback, - TransformerTransformCallback, - ReadableWritablePair, -} from 'stream/web'; -import type { JSONValue, POJO } from '@/types'; +import type { ReadableWritablePair } from 'stream/web'; +import type { JSONValue } from '@/types'; import type { JsonRpcError, JsonRpcMessage, @@ -20,64 +15,30 @@ import { fc } from '@fast-check/jest'; import * as utils from '@/utils'; import { fromError } from '@/RPC/utils'; -class BufferStreamToSnipped implements Transformer { - protected buffer = Buffer.alloc(0); - protected iteration = 0; - protected snippingPattern: Array; - - constructor(snippingPattern: Array) { - this.snippingPattern = snippingPattern; - } - - transform: TransformerTransformCallback = async ( - chunk, - controller, - ) => { - this.buffer = Buffer.concat([this.buffer, chunk]); - while (true) { - const snipAmount = - this.snippingPattern[this.iteration % this.snippingPattern.length]; - if (snipAmount > this.buffer.length) break; - this.iteration += 1; - const returnBuffer = this.buffer.subarray(0, snipAmount); - controller.enqueue(returnBuffer); - this.buffer = this.buffer.subarray(snipAmount); - } - }; - - flush: TransformerFlushCallback = (controller) => { - controller.enqueue(this.buffer); - }; -} - /** * This is used to convert regular chunks into randomly sized chunks based on * a provided pattern. This is to replicate randomness introduced by packets * splitting up the data. */ -class BufferStreamToSnippedStream extends TransformStream { - constructor(snippingPattern: Array) { - super(new BufferStreamToSnipped(snippingPattern)); - } -} - -class BufferStreamToNoisy implements Transformer { - protected iteration = 0; - protected noise: Array; - - constructor(noise: Array) { - this.noise = noise; - } - - transform: TransformerTransformCallback = async ( - chunk, - controller, - ) => { - const noiseBuffer = this.noise[this.iteration % this.noise.length]; - const newBuffer = Buffer.from(Buffer.concat([chunk, noiseBuffer])); - controller.enqueue(newBuffer); - this.iteration += 1; - }; +function binaryStreamToSnippedStream(snippingPattern: Array) { + let buffer = Buffer.alloc(0); + let iteration = 0; + return new TransformStream({ + transform: (chunk, controller) => { + buffer = Buffer.concat([buffer, chunk]); + while (true) { + const snipAmount = snippingPattern[iteration % snippingPattern.length]; + if (snipAmount > buffer.length) break; + iteration += 1; + const returnBuffer = buffer.subarray(0, snipAmount); + controller.enqueue(returnBuffer); + buffer = buffer.subarray(snipAmount); + } + }, + flush: (controller) => { + controller.enqueue(buffer); + }, + }); } /** @@ -85,13 +46,23 @@ class BufferStreamToNoisy implements Transformer { * a provided pattern. This is to replicate randomness introduced by packets * splitting up the data. */ -class BufferStreamToNoisyStream extends TransformStream { - constructor(noise: Array) { - super(new BufferStreamToNoisy(noise)); - } +function binaryStreamToNoisyStream(noise: Array) { + let iteration: number = 0; + return new TransformStream({ + transform: (chunk, controller) => { + const noiseBuffer = noise[iteration % noise.length]; + const newBuffer = Buffer.from(Buffer.concat([chunk, noiseBuffer])); + controller.enqueue(newBuffer); + iteration += 1; + }, + }); } -const jsonRpcStream = (messages: Array) => { +/** + * This takes an array of JsonRpcMessages and converts it to a readable stream. + * Used to seed input for handlers and output for callers. + */ +const messagesToReadableStream = (messages: Array) => { return new ReadableStream({ async start(controller) { for (const arrayElement of messages) { @@ -103,6 +74,11 @@ const jsonRpcStream = (messages: Array) => { }); }; +/** + * Out RPC data is in form of JSON objects. + * This creates a JSON object of the type `JSONValue` and will be unchanged by + * a json stringify and parse cycle. + */ const safeJsonValueArb = fc .jsonValue() .map((value) => JSON.parse(JSON.stringify(value)) as JSONValue); @@ -125,7 +101,7 @@ const jsonRpcRequestMessageArb = ( requiredKeys: ['jsonrpc', 'method', 'id'], }, ) - .noShrink() as fc.Arbitrary>; + .noShrink() as fc.Arbitrary; const jsonRpcRequestNotificationArb = ( method: fc.Arbitrary = fc.string(), @@ -232,20 +208,6 @@ function streamToArray(): [Promise>, WritableStream] { return [result.p, outputStream]; } -class TapTransformer implements Transformer { - protected iteration = 0; - - constructor( - protected tapCallback: (chunk: I, iteration: number) => Promise, - ) {} - - transform: TransformerTransformCallback = async (chunk, controller) => { - await this.tapCallback(chunk, this.iteration); - controller.enqueue(chunk); - this.iteration += 1; - }; -} - type TapCallback = (chunk: T, iteration: number) => Promise; /** @@ -253,18 +215,27 @@ type TapCallback = (chunk: T, iteration: number) => Promise; * a provided pattern. This is to replicate randomness introduced by packets * splitting up the data. */ -class TapTransformerStream extends TransformStream { - constructor(tapCallback: TapCallback = async () => {}) { - super(new TapTransformer(tapCallback)); - } +function tapTransformStream(tapCallback: TapCallback = async () => {}) { + let iteration: number = 0; + return new TransformStream({ + transform: async (chunk, controller) => { + try { + await tapCallback(chunk, iteration); + } catch (e) { + // Ignore errors here + } + controller.enqueue(chunk); + iteration += 1; + }, + }); } function createTapPairs( forwardTapCallback: TapCallback = async () => {}, reverseTapCallback: TapCallback = async () => {}, ) { - const forwardTap = new TapTransformerStream(forwardTapCallback); - const reverseTap = new TapTransformerStream(reverseTapCallback); + const forwardTap = tapTransformStream(forwardTapCallback); + const reverseTap = tapTransformStream(reverseTapCallback); const clientPair: ReadableWritablePair = { readable: reverseTap.readable, writable: forwardTap.writable, @@ -280,9 +251,9 @@ function createTapPairs( } export { - BufferStreamToSnippedStream, - BufferStreamToNoisyStream, - jsonRpcStream, + binaryStreamToSnippedStream, + binaryStreamToNoisyStream, + messagesToReadableStream, safeJsonValueArb, jsonRpcRequestMessageArb, jsonRpcRequestNotificationArb, @@ -296,6 +267,6 @@ export { jsonMessagesArb, rawDataArb, streamToArray, - TapTransformerStream, + tapTransformStream, createTapPairs, }; diff --git a/tests/clientRPC/websocket.test.ts b/tests/clientRPC/websocket.test.ts index 9b1c64380..3f3324e7e 100644 --- a/tests/clientRPC/websocket.test.ts +++ b/tests/clientRPC/websocket.test.ts @@ -2,7 +2,6 @@ import type { TLSConfig } from '@/network/types'; import type { Server } from 'https'; import type { WebSocketServer } from 'ws'; import type { ManifestItem } from '@/RPC/types'; -import type { JSONValue } from '@/types'; import fs from 'fs'; import path from 'path'; import os from 'os'; @@ -27,6 +26,8 @@ describe('websocket', () => { let wss: WebSocketServer; const host = '127.0.0.1'; let port: number; + let rpcServer: RPCServer; + let rpcClient_: RPCClient; beforeEach(async () => { dataDir = await fs.promises.mkdtemp( @@ -46,6 +47,8 @@ describe('websocket', () => { port = await clientRPCUtils.listen(server, host); }); afterEach(async () => { + await rpcClient_?.destroy(); + await rpcServer?.destroy(); wss?.close(); server.close(); await keyRing.stop(); @@ -54,13 +57,13 @@ describe('websocket', () => { test('websocket should work with RPC', async () => { // Setting up server - const test1: ManifestItem = { + const test1: ManifestItem = { type: 'UNARY', handler: async (params, _container, _connectionInfo) => { return params; }, }; - const test2: ManifestItem = { + const test2: ManifestItem = { type: 'UNARY', handler: async () => { return { hello: 'not world' }; @@ -70,7 +73,7 @@ describe('websocket', () => { test1, test2, }; - const rpcServer = await RPCServer.createRPCServer({ + rpcServer = await RPCServer.createRPCServer({ manifest, container: {}, logger: logger.getChild('RPCServer'), @@ -93,6 +96,7 @@ describe('websocket', () => { ); }, }); + rpcClient_ = rpcClient; // Making the call await expect( From 282913eccce74b6d862620c4731a315f493d6b38 Mon Sep 17 00:00:00 2001 From: Brian Botha Date: Thu, 9 Feb 2023 14:57:23 +1100 Subject: [PATCH 33/44] feat: class based client and server manifests - Related #500 - Related #501 [ci skip] --- src/RPC/RPCClient.ts | 20 +- src/RPC/RPCServer.ts | 112 +++++----- src/RPC/callers.ts | 53 +++++ src/RPC/handlers.ts | 67 ++++++ src/RPC/types.ts | 181 +++++++--------- src/RPC/utils.ts | 7 +- src/clientRPC/handlers/agentStatus.ts | 38 ++-- src/clientRPC/handlers/agentUnlock.ts | 33 +-- tests/RPC/RPC.test.ts | 143 ++++++------ tests/RPC/RPCClient.test.ts | 74 ++----- tests/RPC/RPCServer.test.ts | 216 +++++++++---------- tests/clientRPC/handlers/agentStatus.test.ts | 23 +- tests/clientRPC/handlers/agentUnlock.test.ts | 19 +- tests/clientRPC/websocket.test.ts | 39 ++-- 14 files changed, 543 insertions(+), 482 deletions(-) create mode 100644 src/RPC/callers.ts create mode 100644 src/RPC/handlers.ts diff --git a/src/RPC/RPCClient.ts b/src/RPC/RPCClient.ts index b8f6b4276..ee4846b73 100644 --- a/src/RPC/RPCClient.ts +++ b/src/RPC/RPCClient.ts @@ -1,9 +1,9 @@ import type { HandlerType, JsonRpcRequestMessage, - Manifest, - MapWithHandlers, StreamPairCreateCallback, + ClientManifest, + MapWithCallers, } from './types'; import type { JSONValue } from 'types'; import type { @@ -15,7 +15,7 @@ import type { JsonRpcRequest, JsonRpcResponse, MiddlewareFactory, - MapHandlers, + MapCallers, } from './types'; import { CreateDestroy, ready } from '@matrixai/async-init/dist/CreateDestroy'; import Logger from '@matrixai/logger'; @@ -23,10 +23,10 @@ import * as rpcErrors from './errors'; import * as rpcUtils from './utils'; // eslint-disable-next-line -interface RPCClient extends CreateDestroy {} +interface RPCClient extends CreateDestroy {} @CreateDestroy() -class RPCClient { - static async createRPCClient({ +class RPCClient { + static async createRPCClient({ manifest, streamPairCreateCallback, logger = new Logger(this.name), @@ -113,13 +113,13 @@ class RPCClient { } @ready(new rpcErrors.ErrorRpcDestroyed()) - public get methods(): MapHandlers { - return this.methodsProxy as MapHandlers; + public get methods(): MapCallers { + return this.methodsProxy as MapCallers; } @ready(new rpcErrors.ErrorRpcDestroyed()) - public get withMethods(): MapWithHandlers { - return this.withMethodsProxy as MapWithHandlers; + public get withMethods(): MapWithCallers { + return this.withMethodsProxy as MapWithCallers; } @ready(new rpcErrors.ErrorRpcDestroyed()) diff --git a/src/RPC/RPCServer.ts b/src/RPC/RPCServer.ts index d37ba0d66..78713736f 100644 --- a/src/RPC/RPCServer.ts +++ b/src/RPC/RPCServer.ts @@ -1,18 +1,18 @@ import type { - ClientStreamHandler, - DuplexStreamHandler, + ClientHandlerImplementation, + DuplexHandlerImplementation, JsonRpcError, JsonRpcRequest, JsonRpcResponse, JsonRpcResponseError, JsonRpcResponseResult, - Manifest, - RawDuplexStreamHandler, - ServerStreamHandler, - UnaryHandler, + ServerManifest, + RawHandlerImplementation, + ServerHandlerImplementation, + UnaryHandlerImplementation, } from './types'; import type { ReadableWritablePair } from 'stream/web'; -import type { JSONValue, POJO } from '../types'; +import type { JSONValue } from '../types'; import type { ConnectionInfo } from '../network/types'; import type { RPCErrorEvent } from './utils'; import type { MiddlewareFactory } from './types'; @@ -20,6 +20,13 @@ import { ReadableStream } from 'stream/web'; import { CreateDestroy, ready } from '@matrixai/async-init/dist/CreateDestroy'; import Logger from '@matrixai/logger'; import { PromiseCancellable } from '@matrixai/async-cancellable'; +import { + ClientHandler, + DuplexHandler, + RawHandler, + ServerHandler, + UnaryHandler, +} from './handlers'; import * as rpcUtils from './utils'; import * as rpcErrors from './errors'; import { never } from '../utils/utils'; @@ -30,12 +37,10 @@ interface RPCServer extends CreateDestroy {} class RPCServer { static async createRPCServer({ manifest, - container, middleware = rpcUtils.defaultMiddlewareWrapper(), logger = new Logger(this.name), }: { - manifest: Manifest; - container: POJO; + manifest: ServerManifest; middleware?: MiddlewareFactory< JsonRpcRequest, Uint8Array, @@ -47,7 +52,6 @@ class RPCServer { logger.info(`Creating ${this.name}`); const rpcServer = new this({ manifest, - container, middleware, logger, }); @@ -56,9 +60,8 @@ class RPCServer { } // Properties - protected container: POJO; protected logger: Logger; - protected handlerMap: Map = new Map(); + protected handlerMap: Map = new Map(); protected activeStreams: Set> = new Set(); protected events: EventTarget = new EventTarget(); protected middleware: MiddlewareFactory< @@ -70,12 +73,10 @@ class RPCServer { public constructor({ manifest, - container, middleware, logger, }: { - manifest: Manifest; - container: POJO; + manifest: ServerManifest; middleware: MiddlewareFactory< JsonRpcRequest, Uint8Array, @@ -85,27 +86,32 @@ class RPCServer { logger: Logger; }) { for (const [key, manifestItem] of Object.entries(manifest)) { - switch (manifestItem.type) { - case 'RAW': - this.registerRawStreamHandler(key, manifestItem.handler); - continue; - case 'DUPLEX': - this.registerDuplexStreamHandler(key, manifestItem.handler); - continue; - case 'SERVER': - this.registerServerStreamHandler(key, manifestItem.handler); - continue; - case 'CLIENT': - this.registerClientStreamHandler(key, manifestItem.handler); - continue; - case 'UNARY': - this.registerUnaryHandler(key, manifestItem.handler); - continue; - default: - never(); + if (manifestItem instanceof RawHandler) { + this.registerRawStreamHandler(key, manifestItem.handle); + continue; + } + if (manifestItem instanceof DuplexHandler) { + this.registerDuplexStreamHandler(key, manifestItem.handle); + continue; + } + if (manifestItem instanceof ServerHandler) { + this.registerServerStreamHandler(key, manifestItem.handle); + continue; + } + if (manifestItem instanceof ClientHandler) { + this.registerClientStreamHandler(key, manifestItem.handle); + continue; + } + if (manifestItem instanceof ClientHandler) { + this.registerClientStreamHandler(key, manifestItem.handle); + continue; + } + if (manifestItem instanceof UnaryHandler) { + this.registerUnaryHandler(key, manifestItem.handle); + continue; } + never(); } - this.container = container; this.middleware = middleware; this.logger = logger; } @@ -124,7 +130,7 @@ class RPCServer { protected registerRawStreamHandler( method: string, - handler: RawDuplexStreamHandler, + handler: RawHandlerImplementation, ) { this.handlerMap.set(method, handler); } @@ -132,13 +138,12 @@ class RPCServer { protected registerDuplexStreamHandler< I extends JSONValue, O extends JSONValue, - >(method: string, handler: DuplexStreamHandler) { + >(method: string, handler: DuplexHandlerImplementation) { // This needs to handle all the message parsing and conversion from // generators to the raw streams. - const rawSteamHandler: RawDuplexStreamHandler = ( + const rawSteamHandler: RawHandlerImplementation = ( [input, header], - container, connectionInfo, ctx, ) => { @@ -154,12 +159,7 @@ class RPCServer { yield data.params as I; } }; - for await (const response of handler( - dataGen(), - container, - connectionInfo, - ctx, - )) { + for await (const response of handler(dataGen(), connectionInfo, ctx)) { const responseMessage: JsonRpcResponseResult = { jsonrpc: '2.0', result: response, @@ -222,16 +222,15 @@ class RPCServer { protected registerUnaryHandler( method: string, - handler: UnaryHandler, + handler: UnaryHandlerImplementation, ) { - const wrapperDuplex: DuplexStreamHandler = async function* ( + const wrapperDuplex: DuplexHandlerImplementation = async function* ( input, - container, connectionInfo, ctx, ) { for await (const inputVal of input) { - yield handler(inputVal, container, connectionInfo, ctx); + yield handler(inputVal, connectionInfo, ctx); break; } }; @@ -241,15 +240,14 @@ class RPCServer { protected registerServerStreamHandler< I extends JSONValue, O extends JSONValue, - >(method: string, handler: ServerStreamHandler) { - const wrapperDuplex: DuplexStreamHandler = async function* ( + >(method: string, handler: ServerHandlerImplementation) { + const wrapperDuplex: DuplexHandlerImplementation = async function* ( input, - container, connectionInfo, ctx, ) { for await (const inputVal of input) { - yield* handler(inputVal, container, connectionInfo, ctx); + yield* handler(inputVal, connectionInfo, ctx); break; } }; @@ -259,14 +257,13 @@ class RPCServer { protected registerClientStreamHandler< I extends JSONValue, O extends JSONValue, - >(method: string, handler: ClientStreamHandler) { - const wrapperDuplex: DuplexStreamHandler = async function* ( + >(method: string, handler: ClientHandlerImplementation) { + const wrapperDuplex: DuplexHandlerImplementation = async function* ( input, - container, connectionInfo, ctx, ) { - yield handler(input, container, connectionInfo, ctx); + yield handler(input, connectionInfo, ctx); }; this.registerDuplexStreamHandler(method, wrapperDuplex); } @@ -321,7 +318,6 @@ class RPCServer { } const outputStream = handler( [inputStream, leadingMetadataMessage], - this.container, connectionInfo, { signal: abortController.signal }, ); diff --git a/src/RPC/callers.ts b/src/RPC/callers.ts new file mode 100644 index 000000000..6dcfbd930 --- /dev/null +++ b/src/RPC/callers.ts @@ -0,0 +1,53 @@ +import type { JSONValue } from 'types'; +import type { HandlerType } from './types'; + +abstract class Caller< + Input extends JSONValue = JSONValue, + Output extends JSONValue = JSONValue, +> { + protected _inputType: Input; + protected _outputType: Output; + // Need this to distinguish the classes when inferring types + abstract type: HandlerType; +} + +class RawCaller extends Caller { + public type: 'RAW' = 'RAW' as const; +} + +class DuplexCaller< + Input extends JSONValue = JSONValue, + Output extends JSONValue = JSONValue, +> extends Caller { + public type: 'DUPLEX' = 'DUPLEX' as const; +} + +class ServerCaller< + Input extends JSONValue = JSONValue, + Output extends JSONValue = JSONValue, +> extends Caller { + public type: 'SERVER' = 'SERVER' as const; +} + +class ClientCaller< + Input extends JSONValue = JSONValue, + Output extends JSONValue = JSONValue, +> extends Caller { + public type: 'CLIENT' = 'CLIENT' as const; +} + +class UnaryCaller< + Input extends JSONValue = JSONValue, + Output extends JSONValue = JSONValue, +> extends Caller { + public type: 'UNARY' = 'UNARY' as const; +} + +export { + Caller, + RawCaller, + DuplexCaller, + ServerCaller, + ClientCaller, + UnaryCaller, +}; diff --git a/src/RPC/handlers.ts b/src/RPC/handlers.ts new file mode 100644 index 000000000..86d1ba149 --- /dev/null +++ b/src/RPC/handlers.ts @@ -0,0 +1,67 @@ +import type { JSONValue } from 'types'; +import type { + ClientHandlerImplementation, + DuplexHandlerImplementation, + RawHandlerImplementation, + ServerHandlerImplementation, + UnaryHandlerImplementation, + ContainerType, +} from 'RPC/types'; + +abstract class Handler< + Container extends ContainerType = ContainerType, + Input extends JSONValue = JSONValue, + Output extends JSONValue = JSONValue, +> { + protected _inputType: Input; + protected _outputType: Output; + + constructor(protected container: Container) {} +} + +abstract class RawHandler< + Container extends ContainerType = ContainerType, +> extends Handler { + abstract handle: RawHandlerImplementation; +} + +abstract class DuplexHandler< + Container extends ContainerType = ContainerType, + Input extends JSONValue = JSONValue, + Output extends JSONValue = JSONValue, +> extends Handler { + abstract handle: DuplexHandlerImplementation; +} + +abstract class ServerHandler< + Container extends ContainerType = ContainerType, + Input extends JSONValue = JSONValue, + Output extends JSONValue = JSONValue, +> extends Handler { + abstract handle: ServerHandlerImplementation; +} + +abstract class ClientHandler< + Container extends ContainerType = ContainerType, + Input extends JSONValue = JSONValue, + Output extends JSONValue = JSONValue, +> extends Handler { + abstract handle: ClientHandlerImplementation; +} + +abstract class UnaryHandler< + Container extends ContainerType = ContainerType, + Input extends JSONValue = JSONValue, + Output extends JSONValue = JSONValue, +> extends Handler { + abstract handle: UnaryHandlerImplementation; +} + +export { + Handler, + RawHandler, + DuplexHandler, + ServerHandler, + ClientHandler, + UnaryHandler, +}; diff --git a/src/RPC/types.ts b/src/RPC/types.ts index 8d1d6e27b..f4e2a03ff 100644 --- a/src/RPC/types.ts +++ b/src/RPC/types.ts @@ -1,4 +1,4 @@ -import type { JSONValue, POJO } from '../types'; +import type { JSONValue } from '../types'; import type { ConnectionInfo } from '../network/types'; import type { ContextCancellable } from '../contexts/types'; import type { @@ -6,6 +6,15 @@ import type { ReadableWritablePair, WritableStream, } from 'stream/web'; +import type { Handler } from './handlers'; +import type { + Caller, + RawCaller, + DuplexCaller, + ServerCaller, + ClientCaller, + UnaryCaller, +} from './callers'; /** * This is the JSON RPC request object. this is the generic message type used for the RPC. @@ -104,32 +113,33 @@ type JsonRpcMessage = | JsonRpcResponse; // Handler types -type Handler = ( +type HandlerImplementation = ( input: I, - container: POJO, connectionInfo: ConnectionInfo, ctx: ContextCancellable, ) => O; -type RawDuplexStreamHandler = Handler< +type RawHandlerImplementation = HandlerImplementation< [ReadableStream, JsonRpcRequest], ReadableStream >; -type DuplexStreamHandler< +type DuplexHandlerImplementation< I extends JSONValue = JSONValue, O extends JSONValue = JSONValue, -> = Handler, AsyncGenerator>; -type ServerStreamHandler< +> = HandlerImplementation, AsyncGenerator>; +type ServerHandlerImplementation< I extends JSONValue = JSONValue, O extends JSONValue = JSONValue, -> = Handler>; -type ClientStreamHandler< +> = HandlerImplementation>; +type ClientHandlerImplementation< I extends JSONValue = JSONValue, O extends JSONValue = JSONValue, -> = Handler, Promise>; -type UnaryHandler< +> = HandlerImplementation, Promise>; +type UnaryHandlerImplementation< I extends JSONValue = JSONValue, O extends JSONValue = JSONValue, -> = Handler>; +> = HandlerImplementation>; + +type ContainerType = Record; type StreamPairCreateCallback = () => Promise< ReadableWritablePair @@ -140,17 +150,21 @@ type MiddlewareFactory = (header?: JsonRpcRequest) => { reverse: ReadableWritablePair; }; -type DuplexStreamCaller< +type RawCallerImplementation = ( + params: JSONValue, +) => Promise>; + +type DuplexCallerImplementation< I extends JSONValue = JSONValue, O extends JSONValue = JSONValue, > = () => Promise>; -type ServerStreamCaller< +type ServerCallerImplementation< I extends JSONValue = JSONValue, O extends JSONValue = JSONValue, > = (parameters: I) => Promise>; -type ClientStreamCaller< +type ClientCallerImplementation< I extends JSONValue = JSONValue, O extends JSONValue = JSONValue, > = () => Promise<{ @@ -158,58 +172,45 @@ type ClientStreamCaller< writable: WritableStream; }>; -type UnaryCaller< +type UnaryCallerImplementation< I extends JSONValue = JSONValue, O extends JSONValue = JSONValue, > = (parameters: I) => Promise; -type RawStreamCaller = ( - params: JSONValue, -) => Promise>; - -type ConvertDuplexStreamHandler = T extends DuplexStreamHandler< - infer I, - infer O -> - ? DuplexStreamCaller +type ConvertDuplexCaller = T extends DuplexCaller + ? DuplexCallerImplementation : never; -type ConvertServerStreamHandler = T extends ServerStreamHandler< - infer I, - infer O -> - ? ServerStreamCaller +type ConvertServerCaller = T extends ServerCaller + ? ServerCallerImplementation : never; -type ConvertClientStreamHandler = T extends ClientStreamHandler< - infer I, - infer O -> - ? ClientStreamCaller +type ConvertClientCaller = T extends ClientCaller + ? ClientCallerImplementation : never; -type ConvertUnaryCaller = T extends UnaryHandler - ? UnaryCaller +type ConvertUnaryCaller = T extends UnaryCaller + ? UnaryCallerImplementation : never; -type ConvertHandler = T extends DuplexStreamHandler - ? ConvertDuplexStreamHandler - : T extends ServerStreamHandler - ? ConvertServerStreamHandler - : T extends ClientStreamHandler - ? ConvertClientStreamHandler - : T extends UnaryHandler +type ConvertCaller = T extends DuplexCaller + ? ConvertDuplexCaller + : T extends ServerCaller + ? ConvertServerCaller + : T extends ClientCaller + ? ConvertClientCaller + : T extends UnaryCaller ? ConvertUnaryCaller - : T extends RawDuplexStreamHandler - ? RawStreamCaller + : T extends RawCaller + ? RawCallerImplementation : never; -type WithDuplexStreamCaller< +type WithDuplexCallerImplementation< I extends JSONValue = JSONValue, O extends JSONValue = JSONValue, > = (f: (output: AsyncGenerator) => AsyncGenerator) => Promise; -type WithServerStreamCaller< +type WithServerCallerImplementation< I extends JSONValue = JSONValue, O extends JSONValue = JSONValue, > = ( @@ -217,82 +218,58 @@ type WithServerStreamCaller< f: (output: AsyncGenerator) => Promise, ) => Promise; -type WithClientStreamCaller< +type WithClientCallerImplementation< I extends JSONValue = JSONValue, O extends JSONValue = JSONValue, > = (f: () => AsyncGenerator) => Promise; -type WithRawStreamCaller = ( +type WithRawCallerImplementation = ( params: JSONValue, f: (output: AsyncGenerator) => AsyncGenerator, ) => Promise; -type ConvertWithDuplexStreamHandler = T extends DuplexStreamHandler< +type ConvertWithDuplexStreamHandler = T extends DuplexCaller< infer I, infer O > - ? WithDuplexStreamCaller + ? WithDuplexCallerImplementation : never; -type ConvertWithServerStreamHandler = T extends ServerStreamHandler< +type ConvertWithServerStreamHandler = T extends ServerCaller< infer I, infer O > - ? WithServerStreamCaller + ? WithServerCallerImplementation : never; -type ConvertWithClientStreamHandler = T extends ClientStreamHandler< +type ConvertWithClientStreamHandler = T extends ClientCaller< infer I, infer O > - ? WithClientStreamCaller + ? WithClientCallerImplementation : never; -type ConvertWithHandler = T extends DuplexStreamHandler +type ConvertWithHandler = T extends DuplexCaller ? ConvertWithDuplexStreamHandler - : T extends ServerStreamHandler + : T extends ServerCaller ? ConvertWithServerStreamHandler - : T extends ClientStreamHandler + : T extends ClientCaller ? ConvertWithClientStreamHandler - : T extends RawDuplexStreamHandler - ? WithRawStreamCaller + : T extends RawCaller + ? WithRawCallerImplementation : never; +type ServerManifest = Record; +type ClientManifest = Record; + type HandlerType = 'DUPLEX' | 'SERVER' | 'CLIENT' | 'UNARY' | 'RAW'; -type ManifestItem< - I extends JSONValue = JSONValue, - O extends JSONValue = JSONValue, -> = - | { - type: 'DUPLEX'; - handler: DuplexStreamHandler; - } - | { - type: 'SERVER'; - handler: ServerStreamHandler; - } - | { - type: 'CLIENT'; - handler: ClientStreamHandler; - } - | { - type: 'UNARY'; - handler: UnaryHandler; - } - | { - type: 'RAW'; - handler: RawDuplexStreamHandler; - }; - -type Manifest = Record; - -type MapHandlers = { - [K in keyof T]: ConvertHandler; +type MapCallers = { + [K in keyof T]: ConvertCaller; }; -type MapWithHandlers = { - [K in keyof T]: ConvertWithHandler; +type MapWithCallers = { + [K in keyof T]: ConvertWithHandler; }; export type { @@ -304,16 +281,18 @@ export type { JsonRpcRequest, JsonRpcResponse, JsonRpcMessage, - RawDuplexStreamHandler, - DuplexStreamHandler, - ServerStreamHandler, - ClientStreamHandler, - UnaryHandler, + HandlerImplementation, + RawHandlerImplementation, + DuplexHandlerImplementation, + ServerHandlerImplementation, + ClientHandlerImplementation, + UnaryHandlerImplementation, + ContainerType, StreamPairCreateCallback, MiddlewareFactory, + ServerManifest, + ClientManifest, HandlerType, - ManifestItem, - Manifest, - MapHandlers, - MapWithHandlers, + MapCallers, + MapWithCallers, }; diff --git a/src/RPC/utils.ts b/src/RPC/utils.ts index a5d87b143..dad91d132 100644 --- a/src/RPC/utils.ts +++ b/src/RPC/utils.ts @@ -8,9 +8,10 @@ import type { JsonRpcRequest, JsonRpcResponse, MiddlewareFactory, + HandlerType, + ClientManifest, } from 'RPC/types'; import type { JSONValue } from '../types'; -import type { HandlerType, Manifest } from 'RPC/types'; import { TransformStream } from 'stream/web'; import { AbstractError } from '@matrixai/errors'; import * as rpcErrors from './errors'; @@ -563,7 +564,9 @@ function extractFirstMessageTransform( return { headTransformStream, firstMessageProm: messageProm.p }; } -function getHandlerTypes(manifest: Manifest): Record { +function getHandlerTypes( + manifest: ClientManifest, +): Record { const out: Record = {}; for (const [k, v] of Object.entries(manifest)) { out[k] = v.type; diff --git a/src/clientRPC/handlers/agentStatus.ts b/src/clientRPC/handlers/agentStatus.ts index c708f16c7..1f84ffffa 100644 --- a/src/clientRPC/handlers/agentStatus.ts +++ b/src/clientRPC/handlers/agentStatus.ts @@ -1,10 +1,12 @@ -import type { ManifestItem } from '../../RPC/types'; +import type { UnaryHandlerImplementation } from '../../RPC/types'; import type KeyRing from '../../keys/KeyRing'; import type CertManager from '../../keys/CertManager'; import type Logger from '@matrixai/logger'; import type { NodeIdEncoded } from '../../ids'; import * as nodesUtils from '../../nodes/utils'; import * as keysUtils from '../../keys/utils'; +import { UnaryHandler } from '../../RPC/handlers'; +import { UnaryCaller } from '../../RPC/callers'; type StatusResult = { pid: number; @@ -12,26 +14,26 @@ type StatusResult = { publicJwk: string; }; -const agentStatus: ManifestItem = { - type: 'UNARY', - handler: async ( - input, - container: { - keyRing: KeyRing; - certManager: CertManager; - logger: Logger; - }, - _connectionInfo, - _ctx, - ) => { +const agentStatusCaller = new UnaryCaller(); + +class AgentStatusHandler extends UnaryHandler< + { + keyRing: KeyRing; + certManager: CertManager; + logger: Logger; + }, + null, + StatusResult +> { + public handle: UnaryHandlerImplementation = async () => { return { pid: process.pid, - nodeId: nodesUtils.encodeNodeId(container.keyRing.getNodeId()), + nodeId: nodesUtils.encodeNodeId(this.container.keyRing.getNodeId()), publicJwk: JSON.stringify( - keysUtils.publicKeyToJWK(container.keyRing.keyPair.publicKey), + keysUtils.publicKeyToJWK(this.container.keyRing.keyPair.publicKey), ), }; - }, -}; + }; +} -export { agentStatus }; +export { AgentStatusHandler, agentStatusCaller }; diff --git a/src/clientRPC/handlers/agentUnlock.ts b/src/clientRPC/handlers/agentUnlock.ts index 47c804024..bebd5ba08 100644 --- a/src/clientRPC/handlers/agentUnlock.ts +++ b/src/clientRPC/handlers/agentUnlock.ts @@ -1,20 +1,23 @@ -import type { ManifestItem } from '../../RPC/types'; +import type { UnaryHandlerImplementation } from '../../RPC/types'; import type Logger from '@matrixai/logger'; import type { ClientDataAndMetadata } from '../types'; +import { UnaryHandler } from '../../RPC/handlers'; +import { UnaryCaller } from '../../RPC/callers'; -const agentUnlock: ManifestItem< +const agentUnlockCaller = new UnaryCaller< ClientDataAndMetadata, ClientDataAndMetadata -> = { - type: 'UNARY', - handler: async ( - _input, - _container: { - logger: Logger; - }, - _connectionInfo, - _ctx, - ) => { +>(); + +class AgentUnlockHandler extends UnaryHandler< + { logger: Logger }, + ClientDataAndMetadata, + ClientDataAndMetadata +> { + public handle: UnaryHandlerImplementation< + ClientDataAndMetadata, + ClientDataAndMetadata + > = async () => { // This is a NOP handler, // authentication and unlocking is handled via middleware. // Failure to authenticate will be an error from the middleware layer. @@ -22,7 +25,7 @@ const agentUnlock: ManifestItem< metadata: {}, data: null, }; - }, -}; + }; +} -export { agentUnlock }; +export { agentUnlockCaller, AgentUnlockHandler }; diff --git a/tests/RPC/RPC.test.ts b/tests/RPC/RPC.test.ts index 37516a630..0dc9e8c37 100644 --- a/tests/RPC/RPC.test.ts +++ b/tests/RPC/RPC.test.ts @@ -1,9 +1,31 @@ -import type { JsonRpcRequest, ManifestItem } from '@/RPC/types'; +import type { + ClientHandlerImplementation, + ContainerType, + DuplexHandlerImplementation, + JsonRpcRequest, + RawHandlerImplementation, + ServerHandlerImplementation, + UnaryHandlerImplementation, +} from '@/RPC/types'; import type { ConnectionInfo } from '@/network/types'; import { fc, testProp } from '@fast-check/jest'; import Logger, { LogLevel, StreamHandler } from '@matrixai/logger'; import RPCServer from '@/RPC/RPCServer'; import RPCClient from '@/RPC/RPCClient'; +import { + ClientHandler, + DuplexHandler, + RawHandler, + ServerHandler, + UnaryHandler, +} from '@/RPC/handlers'; +import { + ClientCaller, + DuplexCaller, + RawCaller, + ServerCaller, + UnaryCaller, +} from '@/RPC/callers'; import * as rpcTestUtils from './utils'; describe('RPC', () => { @@ -21,26 +43,24 @@ describe('RPC', () => { >(); let header: JsonRpcRequest | undefined; - const testMethod: ManifestItem = { - type: 'RAW', - handler: ([input, header_], _container, _connectionInfo, _ctx) => { + class TestMethod extends RawHandler { + public handle: RawHandlerImplementation = ([input, header_]) => { header = header_; return input; - }, - }; - const manifest = { - testMethod, - }; - const container = {}; + }; + } const rpcServer = await RPCServer.createRPCServer({ - manifest, - container, + manifest: { + testMethod: new TestMethod({}), + }, logger, }); rpcServer.handleStream(serverPair, {} as ConnectionInfo); const rpcClient = await RPCClient.createRPCClient({ - manifest, + manifest: { + testMethod: new RawCaller(), + }, streamPairCreateCallback: async () => clientPair, logger, }); @@ -75,28 +95,25 @@ describe('RPC', () => { Uint8Array, Uint8Array >(); - - const testMethod: ManifestItem = { - type: 'DUPLEX', - handler: async function* (input, _container, _connectionInfo, _ctx) { + class TestMethod extends DuplexHandler { + public handle: DuplexHandlerImplementation = async function* (input) { for await (const val of input) { yield val; } - }, - }; - const manifest = { - testMethod, - }; - const container = {}; + }; + } const rpcServer = await RPCServer.createRPCServer({ - manifest, - container, + manifest: { + testMethod: new TestMethod({}), + }, logger, }); rpcServer.handleStream(serverPair, {} as ConnectionInfo); const rpcClient = await RPCClient.createRPCClient({ - manifest, + manifest: { + testMethod: new DuplexCaller(), + }, streamPairCreateCallback: async () => clientPair, logger, }); @@ -125,27 +142,27 @@ describe('RPC', () => { Uint8Array >(); - const testMethod: ManifestItem = { - type: 'SERVER', - handler: async function* (input, _container, _connectionInfo, _ctx) { - for (let i = 0; i < input; i++) { - yield i; - } - }, - }; - const manifest = { - testMethod, - }; - const container = {}; + class TestMethod extends ServerHandler { + public handle: ServerHandlerImplementation = + async function* (input) { + for (let i = 0; i < input; i++) { + yield i; + } + }; + } + const rpcServer = await RPCServer.createRPCServer({ - manifest, - container, + manifest: { + testMethod: new TestMethod({}), + }, logger, }); rpcServer.handleStream(serverPair, {} as ConnectionInfo); const rpcClient = await RPCClient.createRPCClient({ - manifest, + manifest: { + testMethod: new ServerCaller(), + }, streamPairCreateCallback: async () => clientPair, logger, }); @@ -170,29 +187,29 @@ describe('RPC', () => { Uint8Array >(); - const testMethod: ManifestItem = { - type: 'CLIENT', - handler: async (input) => { + class TestMethod extends ClientHandler { + public handle: ClientHandlerImplementation = async ( + input, + ) => { let acc = 0; for await (const number of input) { acc += number; } return acc; - }, - }; - const manifest = { - testMethod, - }; - const container = {}; + }; + } const rpcServer = await RPCServer.createRPCServer({ - manifest, - container, + manifest: { + testMethod: new TestMethod({}), + }, logger, }); rpcServer.handleStream(serverPair, {} as ConnectionInfo); const rpcClient = await RPCClient.createRPCClient({ - manifest, + manifest: { + testMethod: new ClientCaller(), + }, streamPairCreateCallback: async () => clientPair, logger, }); @@ -219,23 +236,21 @@ describe('RPC', () => { Uint8Array >(); - const testMethod: ManifestItem = { - type: 'UNARY', - handler: async (input) => input, - }; - const manifest = { - testMethod, - }; - const container = {}; + class TestMethod extends UnaryHandler { + public handle: UnaryHandlerImplementation = async (input) => input; + } const rpcServer = await RPCServer.createRPCServer({ - manifest, - container, + manifest: { + testMethod: new TestMethod({}), + }, logger, }); rpcServer.handleStream(serverPair, {} as ConnectionInfo); const rpcClient = await RPCClient.createRPCClient({ - manifest, + manifest: { + testMethod: new UnaryCaller(), + }, streamPairCreateCallback: async () => clientPair, logger, }); diff --git a/tests/RPC/RPCClient.test.ts b/tests/RPC/RPCClient.test.ts index c5fd9212b..3af073353 100644 --- a/tests/RPC/RPCClient.test.ts +++ b/tests/RPC/RPCClient.test.ts @@ -4,7 +4,6 @@ import type { JsonRpcRequest, JsonRpcRequestMessage, JsonRpcResponse, - ManifestItem, } from '@/RPC/types'; import { TransformStream, ReadableStream } from 'stream/web'; import Logger, { LogLevel, StreamHandler } from '@matrixai/logger'; @@ -12,6 +11,13 @@ import { testProp, fc } from '@fast-check/jest'; import RPCClient from '@/RPC/RPCClient'; import RPCServer from '@/RPC/RPCServer'; import * as rpcErrors from '@/RPC/errors'; +import { + ClientCaller, + DuplexCaller, + RawCaller, + ServerCaller, + UnaryCaller, +} from '@/RPC/callers'; import * as rpcTestUtils from './utils'; describe(`${RPCClient.name}`, () => { @@ -558,15 +564,9 @@ describe(`${RPCClient.name}`, () => { readable: inputStream, writable: outputStream, }; - const duplex: ManifestItem = { - type: 'DUPLEX', - handler: async function* (input) { - yield* input; - }, - }; const rpcClient = await RPCClient.createRPCClient({ manifest: { - duplex, + duplex: new DuplexCaller(), }, streamPairCreateCallback: async () => streamPair, logger, @@ -612,15 +612,9 @@ describe(`${RPCClient.name}`, () => { readable: inputStream, writable: outputStream, }; - const server: ManifestItem = { - type: 'SERVER', - handler: async function* (input) { - yield* input; - }, - }; const rpcClient = await RPCClient.createRPCClient({ manifest: { - server, + server: new ServerCaller(), }, streamPairCreateCallback: async () => streamPair, logger, @@ -657,15 +651,9 @@ describe(`${RPCClient.name}`, () => { readable: inputStream, writable: outputStream, }; - const client: ManifestItem = { - type: 'CLIENT', - handler: async (_) => { - return 'hello'; - }, - }; const rpcClient = await RPCClient.createRPCClient({ manifest: { - client, + client: new ClientCaller(), }, streamPairCreateCallback: async () => streamPair, logger, @@ -701,13 +689,9 @@ describe(`${RPCClient.name}`, () => { readable: inputStream, writable: outputStream, }; - const unary: ManifestItem = { - type: 'UNARY', - handler: async (input) => input, - }; const rpcClient = await RPCClient.createRPCClient({ manifest: { - unary, + unary: new UnaryCaller(), }, streamPairCreateCallback: async () => streamPair, logger, @@ -748,13 +732,9 @@ describe(`${RPCClient.name}`, () => { }), writable: inputWritableStream, }; - const raw: ManifestItem = { - type: 'RAW', - handler: ([input]) => input, - }; const rpcClient = await RPCClient.createRPCClient({ manifest: { - raw, + raw: new RawCaller(), }, streamPairCreateCallback: async () => streamPair, logger, @@ -795,15 +775,9 @@ describe(`${RPCClient.name}`, () => { readable: inputStream, writable: outputStream, }; - const duplex: ManifestItem = { - type: 'DUPLEX', - handler: async function* (input) { - yield* input; - }, - }; const rpcClient = await RPCClient.createRPCClient({ manifest: { - duplex, + duplex: new DuplexCaller(), }, streamPairCreateCallback: async () => streamPair, logger, @@ -836,15 +810,9 @@ describe(`${RPCClient.name}`, () => { readable: inputStream, writable: outputStream, }; - const server: ManifestItem = { - type: 'SERVER', - handler: async function* (input) { - yield input; - }, - }; const rpcClient = await RPCClient.createRPCClient({ manifest: { - server, + server: new ServerCaller(), }, streamPairCreateCallback: async () => streamPair, logger, @@ -880,15 +848,9 @@ describe(`${RPCClient.name}`, () => { readable: inputStream, writable: outputStream, }; - const client: ManifestItem = { - type: 'CLIENT', - handler: async (_) => { - return 'someValue'; - }, - }; const rpcClient = await RPCClient.createRPCClient({ manifest: { - client, + client: new ClientCaller(), }, streamPairCreateCallback: async () => streamPair, logger, @@ -934,13 +896,9 @@ describe(`${RPCClient.name}`, () => { }), writable: inputWritableStream, }; - const raw: ManifestItem = { - type: 'RAW', - handler: ([input]) => input, - }; const rpcClient = await RPCClient.createRPCClient({ manifest: { - raw, + raw: new RawCaller(), }, streamPairCreateCallback: async () => streamPair, logger, diff --git a/tests/RPC/RPCServer.test.ts b/tests/RPC/RPCServer.test.ts index d21215026..a613e2ec5 100644 --- a/tests/RPC/RPCServer.test.ts +++ b/tests/RPC/RPCServer.test.ts @@ -1,8 +1,13 @@ import type { + ClientHandlerImplementation, + ContainerType, + DuplexHandlerImplementation, JsonRpcRequest, JsonRpcResponse, JsonRpcResponseError, - ManifestItem, + RawHandlerImplementation, + ServerHandlerImplementation, + UnaryHandlerImplementation, } from '@/RPC/types'; import type { JSONValue } from '@/types'; import type { ConnectionInfo, Host, Port } from '@/network/types'; @@ -14,6 +19,13 @@ import Logger, { LogLevel, StreamHandler } from '@matrixai/logger'; import RPCServer from '@/RPC/RPCServer'; import * as rpcErrors from '@/RPC/errors'; import * as rpcUtils from '@/RPC/utils'; +import { + ClientHandler, + DuplexHandler, + RawHandler, + ServerHandler, + UnaryHandler, +} from '@/RPC/handlers'; import * as rpcTestUtils from './utils'; describe(`${RPCServer.name}`, () => { @@ -61,9 +73,8 @@ describe(`${RPCServer.name}`, () => { .pipeThrough( rpcTestUtils.binaryStreamToSnippedStream([4, 7, 13, 2, 6]), ); - const testMethod: ManifestItem = { - type: 'RAW', - handler: ([input]) => { + class TestHandler extends RawHandler { + public handle: RawHandlerImplementation = ([input]) => { void (async () => { for await (const _ of input) { // No touch, only consume @@ -75,14 +86,12 @@ describe(`${RPCServer.name}`, () => { controller.close(); }, }); - }, - }; - const container = {}; + }; + } const rpcServer = await RPCServer.createRPCServer({ manifest: { - testMethod, + testMethod: new TestHandler({}), }, - container, logger, }); const [outputResult, outputStream] = rpcTestUtils.streamToArray(); @@ -101,21 +110,18 @@ describe(`${RPCServer.name}`, () => { [specificMessageArb], async (messages) => { const stream = rpcTestUtils.messagesToReadableStream(messages); - const testMethod: ManifestItem = { - type: 'DUPLEX', - handler: async function* (input, _container, _connectionInfo, _ctx) { + class TestMethod extends DuplexHandler { + public handle: DuplexHandlerImplementation = async function* (input) { for await (const val of input) { yield val; break; } - }, - }; - const container = {}; + }; + } const rpcServer = await RPCServer.createRPCServer({ manifest: { - testMethod, + testMethod: new TestMethod({}), }, - container, logger, }); const [outputResult, outputStream] = rpcTestUtils.streamToArray(); @@ -133,22 +139,19 @@ describe(`${RPCServer.name}`, () => { [specificMessageArb], async (messages) => { const stream = rpcTestUtils.messagesToReadableStream(messages); - const testMethod: ManifestItem = { - type: 'CLIENT', - handler: async function (input, _container, _connectionInfo, _ctx) { + class TestMethod extends ClientHandler { + public handle: ClientHandlerImplementation = async function (input) { let count = 0; for await (const _ of input) { count += 1; } return count; - }, - }; - const container = {}; + }; + } const rpcServer = await RPCServer.createRPCServer({ manifest: { - testMethod, + testMethod: new TestMethod({}), }, - container, logger, }); const [outputResult, outputStream] = rpcTestUtils.streamToArray(); @@ -166,20 +169,18 @@ describe(`${RPCServer.name}`, () => { [singleNumberMessageArb], async (messages) => { const stream = rpcTestUtils.messagesToReadableStream(messages); - const testMethod: ManifestItem = { - type: 'SERVER', - handler: async function* (input, _container, _connectionInfo, _ctx) { - for (let i = 0; i < input; i++) { - yield i; - } - }, - }; - const container = {}; + class TestMethod extends ServerHandler { + public handle: ServerHandlerImplementation = + async function* (input) { + for (let i = 0; i < input; i++) { + yield i; + } + }; + } const rpcServer = await RPCServer.createRPCServer({ manifest: { - testMethod, + testMethod: new TestMethod({}), }, - container, logger, }); const [outputResult, outputStream] = rpcTestUtils.streamToArray(); @@ -197,16 +198,13 @@ describe(`${RPCServer.name}`, () => { [specificMessageArb], async (messages) => { const stream = rpcTestUtils.messagesToReadableStream(messages); - const testMethod: ManifestItem = { - type: 'UNARY', - handler: async (input, _container, _connectionInfo, _ctx) => input, - }; - const container = {}; + class TestMethod extends UnaryHandler { + public handle: UnaryHandlerImplementation = async (input) => input; + } const rpcServer = await RPCServer.createRPCServer({ manifest: { - testMethod, + testMethod: new TestMethod({}), }, - container, logger, }); const [outputResult, outputStream] = rpcTestUtils.streamToArray(); @@ -224,25 +222,24 @@ describe(`${RPCServer.name}`, () => { [specificMessageArb], async (messages) => { const stream = rpcTestUtils.messagesToReadableStream(messages); - const testMethod: ManifestItem = { - type: 'DUPLEX', - handler: async function* (input, container_, _connectionInfo, _ctx) { - expect(container_).toBe(container); - for await (const val of input) { - yield val; - } - }, - }; const container = { a: Symbol('a'), B: Symbol('b'), C: Symbol('c'), }; + class TestMethod extends DuplexHandler { + public handle: DuplexHandlerImplementation = async function* (input) { + expect(this.container).toBe(container); + for await (const val of input) { + yield val; + } + }; + } + const rpcServer = await RPCServer.createRPCServer({ manifest: { - testMethod, + testMethod: new TestMethod(container), }, - container, logger, }); const [outputResult, outputStream] = rpcTestUtils.streamToArray(); @@ -269,21 +266,21 @@ describe(`${RPCServer.name}`, () => { remotePort: 12341 as Port, }; let handledConnectionInfo; - const testMethod: ManifestItem = { - type: 'DUPLEX', - handler: async function* (input, _container, connectionInfo_, _ctx) { + class TestMethod extends DuplexHandler { + public handle: DuplexHandlerImplementation = async function* ( + input, + connectionInfo_, + ) { handledConnectionInfo = connectionInfo_; for await (const val of input) { yield val; } - }, - }; - const container = {}; + }; + } const rpcServer = await RPCServer.createRPCServer({ manifest: { - testMethod, + testMethod: new TestMethod({}), }, - container, logger, }); const [outputResult, outputStream] = rpcTestUtils.streamToArray(); @@ -301,21 +298,22 @@ describe(`${RPCServer.name}`, () => { // If I don't pipe the tap to the output we actually iterate over some data. testProp('Handler can be aborted', [specificMessageArb], async (messages) => { const stream = rpcTestUtils.messagesToReadableStream(messages); - const testMethod: ManifestItem = { - type: 'DUPLEX', - handler: async function* (input, _container, _connectionInfo, ctx) { + class TestMethod extends DuplexHandler { + public handle: DuplexHandlerImplementation = async function* ( + input, + _connectionInf, + ctx, + ) { for await (const val of input) { if (ctx.signal.aborted) throw ctx.signal.reason; yield val; } - }, - }; - const container = {}; + }; + } const rpcServer = await RPCServer.createRPCServer({ manifest: { - testMethod, + testMethod: new TestMethod({}), }, - container, logger, }); const [outputResult, outputStream] = @@ -351,20 +349,17 @@ describe(`${RPCServer.name}`, () => { }); testProp('Handler yields nothing', [specificMessageArb], async (messages) => { const stream = rpcTestUtils.messagesToReadableStream(messages); - const testMethod: ManifestItem = { - type: 'DUPLEX', - handler: async function* (input, _container, _connectionInfo, _ctx) { + class TestMethod extends DuplexHandler { + public handle: DuplexHandlerImplementation = async function* (input) { for await (const _ of input) { // Do nothing, just consume } - }, - }; - const container = {}; + }; + } const rpcServer = await RPCServer.createRPCServer({ manifest: { - testMethod, + testMethod: new TestMethod({}), }, - container, logger, }); const [outputResult, outputStream] = rpcTestUtils.streamToArray(); @@ -382,18 +377,15 @@ describe(`${RPCServer.name}`, () => { [specificMessageArb, errorArb], async (messages, error) => { const stream = rpcTestUtils.messagesToReadableStream(messages); - const testMethod: ManifestItem = { - type: 'DUPLEX', - handler: async function* (_input, _container, _connectionInfo, _ctx) { + class TestMethod extends DuplexHandler { + public handle: DuplexHandlerImplementation = async function* () { throw error; - }, - }; - const container = {}; + }; + } const rpcServer = await RPCServer.createRPCServer({ manifest: { - testMethod, + testMethod: new TestMethod({}), }, - container, logger, }); let resolve, reject; @@ -423,18 +415,15 @@ describe(`${RPCServer.name}`, () => { [specificMessageArb], async (messages) => { const stream = rpcTestUtils.messagesToReadableStream(messages); - const testMethod: ManifestItem = { - type: 'DUPLEX', - handler: async function* (_input, _container, _connectionInfo, _ctx) { + class TestMethod extends DuplexHandler { + public handle: DuplexHandlerImplementation = async function* () { throw new rpcErrors.ErrorRpcPlaceholderConnectionError(); - }, - }; - const container = {}; + }; + } const rpcServer = await RPCServer.createRPCServer({ manifest: { - testMethod, + testMethod: new TestMethod({}), }, - container, logger, }); let resolve, reject; @@ -460,15 +449,13 @@ describe(`${RPCServer.name}`, () => { ); testProp('forward middlewares', [specificMessageArb], async (messages) => { const stream = rpcTestUtils.messagesToReadableStream(messages); - const testMethod: ManifestItem = { - type: 'DUPLEX', - handler: async function* (input, _container, _connectionInfo, _ctx) { + class TestMethod extends DuplexHandler { + public handle: DuplexHandlerImplementation = async function* (input) { for await (const val of input) { yield val; } - }, - }; - const container = {}; + }; + } const middleware = rpcUtils.defaultMiddlewareWrapper(() => { return { forward: new TransformStream({ @@ -482,10 +469,9 @@ describe(`${RPCServer.name}`, () => { }); const rpcServer = await RPCServer.createRPCServer({ manifest: { - testMethod, + testMethod: new TestMethod({}), }, middleware, - container, logger, }); const [outputResult, outputStream] = rpcTestUtils.streamToArray(); @@ -508,15 +494,13 @@ describe(`${RPCServer.name}`, () => { }); testProp('reverse middlewares', [specificMessageArb], async (messages) => { const stream = rpcTestUtils.messagesToReadableStream(messages); - const testMethod: ManifestItem = { - type: 'DUPLEX', - handler: async function* (input, _container, _connectionInfo, _ctx) { + class TestMethod extends DuplexHandler { + public handle: DuplexHandlerImplementation = async function* (input) { for await (const val of input) { yield val; } - }, - }; - const container = {}; + }; + } const middleware = rpcUtils.defaultMiddlewareWrapper(() => { return { forward: new TransformStream(), @@ -530,10 +514,9 @@ describe(`${RPCServer.name}`, () => { }); const rpcServer = await RPCServer.createRPCServer({ manifest: { - testMethod, + testMethod: new TestMethod({}), }, middleware, - container, logger, }); const [outputResult, outputStream] = rpcTestUtils.streamToArray(); @@ -559,15 +542,13 @@ describe(`${RPCServer.name}`, () => { [invalidTokenMessageArb], async (message) => { const stream = rpcTestUtils.messagesToReadableStream([message]); - const testMethod: ManifestItem = { - type: 'DUPLEX', - handler: async function* (input, _container, _connectionInfo, _ctx) { + class TestMethod extends DuplexHandler { + public handle: DuplexHandlerImplementation = async function* (input) { for await (const val of input) { yield val; } - }, - }; - const container = {}; + }; + } const middleware = rpcUtils.defaultMiddlewareWrapper(() => { let first = true; let reverseController: TransformStreamDefaultController; @@ -600,10 +581,9 @@ describe(`${RPCServer.name}`, () => { }); const rpcServer = await RPCServer.createRPCServer({ manifest: { - testMethod, + testMethod: new TestMethod({}), }, middleware, - container, logger, }); const [outputResult, outputStream] = rpcTestUtils.streamToArray(); diff --git a/tests/clientRPC/handlers/agentStatus.test.ts b/tests/clientRPC/handlers/agentStatus.test.ts index 4b94d381f..8bf1d897a 100644 --- a/tests/clientRPC/handlers/agentStatus.test.ts +++ b/tests/clientRPC/handlers/agentStatus.test.ts @@ -11,7 +11,10 @@ import * as keysUtils from '@/keys/utils'; import RPCServer from '@/RPC/RPCServer'; import TaskManager from '@/tasks/TaskManager'; import CertManager from '@/keys/CertManager'; -import { agentStatus } from '@/clientRPC/handlers/agentStatus'; +import { + agentStatusCaller, + AgentStatusHandler, +} from '@/clientRPC/handlers/agentStatus'; import RPCClient from '@/RPC/RPCClient'; import * as clientRPCUtils from '@/clientRPC/utils'; import * as nodesUtils from '@/nodes/utils'; @@ -78,15 +81,13 @@ describe('agentStatus', () => { }); test('get status %s', async () => { // Setup - const manifest = { - agentStatus, - }; const rpcServer = await RPCServer.createRPCServer({ - manifest, - container: { - keyRing, - certManager, - logger: logger.getChild('container'), + manifest: { + agentStatus: new AgentStatusHandler({ + keyRing, + certManager, + logger: logger.getChild('container'), + }), }, logger: logger.getChild('RPCServer'), }); @@ -96,7 +97,9 @@ describe('agentStatus', () => { logger.getChild('server'), ); const rpcClient = await RPCClient.createRPCClient({ - manifest, + manifest: { + agentStatus: agentStatusCaller, + }, streamPairCreateCallback: async () => { return clientRPCUtils.startConnection( host, diff --git a/tests/clientRPC/handlers/agentUnlock.test.ts b/tests/clientRPC/handlers/agentUnlock.test.ts index 4397df001..0350fc4d4 100644 --- a/tests/clientRPC/handlers/agentUnlock.test.ts +++ b/tests/clientRPC/handlers/agentUnlock.test.ts @@ -11,7 +11,10 @@ import * as keysUtils from '@/keys/utils'; import RPCServer from '@/RPC/RPCServer'; import TaskManager from '@/tasks/TaskManager'; import CertManager from '@/keys/CertManager'; -import { agentUnlock } from '@/clientRPC/handlers/agentUnlock'; +import { + agentUnlockCaller, + AgentUnlockHandler, +} from '@/clientRPC/handlers/agentUnlock'; import RPCClient from '@/RPC/RPCClient'; import { Session, SessionManager } from '@/sessions'; import * as clientRPCUtils from '@/clientRPC/utils'; @@ -90,17 +93,13 @@ describe('agentUnlock', () => { }); test('get status', async () => { // Setup - const manifest = { - agentUnlock, - }; const rpcServer = await RPCServer.createRPCServer({ - manifest, + manifest: { + agentUnlock: new AgentUnlockHandler({ logger }), + }, middleware: rpcUtils.defaultMiddlewareWrapper( clientRPCUtils.authenticationMiddlewareServer(sessionManager, keyRing), ), - container: { - logger, - }, logger, }); wss = clientRPCUtils.createClientServer( @@ -109,7 +108,9 @@ describe('agentUnlock', () => { logger.getChild('server'), ); const rpcClient = await RPCClient.createRPCClient({ - manifest, + manifest: { + agentUnlock: agentUnlockCaller, + }, streamPairCreateCallback: async () => { return clientRPCUtils.startConnection( '127.0.0.1', diff --git a/tests/clientRPC/websocket.test.ts b/tests/clientRPC/websocket.test.ts index 3f3324e7e..ce117cffb 100644 --- a/tests/clientRPC/websocket.test.ts +++ b/tests/clientRPC/websocket.test.ts @@ -1,7 +1,7 @@ import type { TLSConfig } from '@/network/types'; import type { Server } from 'https'; import type { WebSocketServer } from 'ws'; -import type { ManifestItem } from '@/RPC/types'; +import type { ClientManifest } from '@/RPC/types'; import fs from 'fs'; import path from 'path'; import os from 'os'; @@ -11,6 +11,8 @@ import RPCServer from '@/RPC/RPCServer'; import RPCClient from '@/RPC/RPCClient'; import { KeyRing } from '@/keys/index'; import * as clientRPCUtils from '@/clientRPC/utils'; +import { UnaryHandler } from '@/RPC/handlers'; +import { UnaryCaller } from '@/RPC/callers'; import * as testsUtils from '../utils/index'; describe('websocket', () => { @@ -27,7 +29,7 @@ describe('websocket', () => { const host = '127.0.0.1'; let port: number; let rpcServer: RPCServer; - let rpcClient_: RPCClient; + let rpcClient_: RPCClient; beforeEach(async () => { dataDir = await fs.promises.mkdtemp( @@ -57,25 +59,21 @@ describe('websocket', () => { test('websocket should work with RPC', async () => { // Setting up server - const test1: ManifestItem = { - type: 'UNARY', - handler: async (params, _container, _connectionInfo) => { + class Test1 extends UnaryHandler { + public handle = async (params) => { return params; - }, - }; - const test2: ManifestItem = { - type: 'UNARY', - handler: async () => { + }; + } + class Test2 extends UnaryHandler { + public handle = async () => { return { hello: 'not world' }; - }, - }; - const manifest = { - test1, - test2, - }; + }; + } rpcServer = await RPCServer.createRPCServer({ - manifest, - container: {}, + manifest: { + test1: new Test1({}), + test2: new Test2({}), + }, logger: logger.getChild('RPCServer'), }); wss = clientRPCUtils.createClientServer( @@ -86,7 +84,10 @@ describe('websocket', () => { // Setting up client const rpcClient = await RPCClient.createRPCClient({ - manifest, + manifest: { + test1: new UnaryCaller(), + test2: new UnaryCaller(), + }, logger: logger.getChild('RPCClient'), streamPairCreateCallback: async () => { return clientRPCUtils.startConnection( From cad1d070893c1b680d2f3e37d2102de31a1e21a9 Mon Sep 17 00:00:00 2001 From: Brian Botha Date: Fri, 10 Feb 2023 13:22:51 +1100 Subject: [PATCH 34/44] fix: updated metadata type usage [ci skip] --- src/clientRPC/handlers/agentStatus.ts | 9 ++-- src/clientRPC/handlers/agentUnlock.ts | 32 ++++++-------- src/clientRPC/types.ts | 16 ++++--- src/clientRPC/utils.ts | 45 +++++++++++--------- tests/clientRPC/handlers/agentStatus.test.ts | 2 +- tests/clientRPC/handlers/agentUnlock.test.ts | 8 +--- 6 files changed, 55 insertions(+), 57 deletions(-) diff --git a/src/clientRPC/handlers/agentStatus.ts b/src/clientRPC/handlers/agentStatus.ts index 1f84ffffa..34fbc24e7 100644 --- a/src/clientRPC/handlers/agentStatus.ts +++ b/src/clientRPC/handlers/agentStatus.ts @@ -7,6 +7,7 @@ import * as nodesUtils from '../../nodes/utils'; import * as keysUtils from '../../keys/utils'; import { UnaryHandler } from '../../RPC/handlers'; import { UnaryCaller } from '../../RPC/callers'; +import { WithMetadata } from "@/clientRPC/types"; type StatusResult = { pid: number; @@ -14,7 +15,7 @@ type StatusResult = { publicJwk: string; }; -const agentStatusCaller = new UnaryCaller(); +const agentStatusCaller = new UnaryCaller>(); class AgentStatusHandler extends UnaryHandler< { @@ -22,10 +23,10 @@ class AgentStatusHandler extends UnaryHandler< certManager: CertManager; logger: Logger; }, - null, - StatusResult + WithMetadata, + WithMetadata > { - public handle: UnaryHandlerImplementation = async () => { + public handle: UnaryHandlerImplementation> = async () => { return { pid: process.pid, nodeId: nodesUtils.encodeNodeId(this.container.keyRing.getNodeId()), diff --git a/src/clientRPC/handlers/agentUnlock.ts b/src/clientRPC/handlers/agentUnlock.ts index bebd5ba08..b4f7febe9 100644 --- a/src/clientRPC/handlers/agentUnlock.ts +++ b/src/clientRPC/handlers/agentUnlock.ts @@ -1,31 +1,27 @@ import type { UnaryHandlerImplementation } from '../../RPC/types'; import type Logger from '@matrixai/logger'; -import type { ClientDataAndMetadata } from '../types'; +import type { WithMetadata } from '../types'; import { UnaryHandler } from '../../RPC/handlers'; import { UnaryCaller } from '../../RPC/callers'; -const agentUnlockCaller = new UnaryCaller< - ClientDataAndMetadata, - ClientDataAndMetadata ->(); +const agentUnlockCaller = new UnaryCaller(); class AgentUnlockHandler extends UnaryHandler< { logger: Logger }, - ClientDataAndMetadata, - ClientDataAndMetadata + WithMetadata, + WithMetadata > { - public handle: UnaryHandlerImplementation< - ClientDataAndMetadata, - ClientDataAndMetadata - > = async () => { - // This is a NOP handler, - // authentication and unlocking is handled via middleware. - // Failure to authenticate will be an error from the middleware layer. - return { - metadata: {}, - data: null, + public handle: UnaryHandlerImplementation = + async () => { + // This is a NOP handler, + // authentication and unlocking is handled via middleware. + // Failure to authenticate will be an error from the middleware layer. + return { + metadata: { + Authorization: '', + }, + }; }; - }; } export { agentUnlockCaller, AgentUnlockHandler }; diff --git a/src/clientRPC/types.ts b/src/clientRPC/types.ts index b570749f1..8ae04e161 100644 --- a/src/clientRPC/types.ts +++ b/src/clientRPC/types.ts @@ -1,10 +1,12 @@ import type { JSONValue } from '../types'; -type ClientDataAndMetadata = { - metadata: JSONValue & { - Authorization?: string; - }; - data: T; -}; +// eslint-disable-next-line +type NoData = {}; -export type { ClientDataAndMetadata }; +type WithMetadata = NoData> = { + metadata?: { + [Key: string]: JSONValue; + } & Partial<{ Authorization: string }>; +} & Omit; + +export type { WithMetadata, NoData }; diff --git a/src/clientRPC/utils.ts b/src/clientRPC/utils.ts index 2a7d074d0..1e8d189dc 100644 --- a/src/clientRPC/utils.ts +++ b/src/clientRPC/utils.ts @@ -2,8 +2,7 @@ import type { SessionToken } from '../sessions/types'; import type KeyRing from '../keys/KeyRing'; import type SessionManager from '../sessions/SessionManager'; import type { Session } from '../sessions'; -import type { ClientDataAndMetadata } from './types'; -import type { JSONValue } from '../types'; +import type { WithMetadata } from './types'; import type { JsonRpcRequest, JsonRpcResponse, @@ -26,7 +25,7 @@ import { promise } from '../utils'; async function authenticate( sessionManager: SessionManager, keyRing: KeyRing, - message: JsonRpcRequest>, + message: JsonRpcRequest, ) { if (message.params == null) throw new clientErrors.ErrorClientAuthMissing(); if (message.params.metadata == null) { @@ -59,8 +58,8 @@ async function authenticate( return `Bearer ${token}`; } -function decodeAuth(messageParams: ClientDataAndMetadata) { - const auth = messageParams.metadata.Authorization; +function decodeAuth(messageParams: WithMetadata) { + const auth = messageParams.metadata?.Authorization; if (auth == null || !auth.startsWith('Bearer ')) { return; } @@ -76,10 +75,10 @@ function authenticationMiddlewareServer( sessionManager: SessionManager, keyRing: KeyRing, ): MiddlewareFactory< - JsonRpcRequest>, - JsonRpcRequest>, - JsonRpcResponse>, - JsonRpcResponse> + JsonRpcRequest, + JsonRpcRequest, + JsonRpcResponse, + JsonRpcResponse > { return () => { let forwardFirst = true; @@ -87,8 +86,8 @@ function authenticationMiddlewareServer( let outgoingToken: string | null = null; return { forward: new TransformStream< - JsonRpcRequest>, - JsonRpcRequest> + JsonRpcRequest, + JsonRpcRequest >({ transform: async (chunk, controller) => { if (forwardFirst) { @@ -115,6 +114,9 @@ function authenticationMiddlewareServer( transform: (chunk, controller) => { // Add the outgoing metadata to the next message. if (outgoingToken != null && 'result' in chunk) { + if (chunk.result.metadata == null) chunk.result.metadata = { + Authorization: '', + } chunk.result.metadata.Authorization = outgoingToken; outgoingToken = null; } @@ -128,24 +130,27 @@ function authenticationMiddlewareServer( function authenticationMiddlewareClient( session: Session, ): MiddlewareFactory< - JsonRpcRequest>, - JsonRpcRequest>, - JsonRpcResponse>, - JsonRpcResponse> + JsonRpcRequest, + JsonRpcRequest, + JsonRpcResponse, + JsonRpcResponse > { return () => { let forwardFirst = true; return { forward: new TransformStream< - JsonRpcRequest>, - JsonRpcRequest> + JsonRpcRequest, + JsonRpcRequest >({ transform: async (chunk, controller) => { if (forwardFirst) { if (chunk.params == null) utils.never(); - if (chunk.params.metadata.Authorization == null) { + if (chunk.params.metadata?.Authorization == null) { const token = await session.readToken(); if (token != null) { + if (chunk.params.metadata == null) chunk.params.metadata = { + Authorization: '', + } chunk.params.metadata.Authorization = `Bearer ${token}`; } } @@ -155,8 +160,8 @@ function authenticationMiddlewareClient( }, }), reverse: new TransformStream< - JsonRpcResponse>, - JsonRpcResponse> + JsonRpcResponse, + JsonRpcResponse >({ transform: async (chunk, controller) => { controller.enqueue(chunk); diff --git a/tests/clientRPC/handlers/agentStatus.test.ts b/tests/clientRPC/handlers/agentStatus.test.ts index 8bf1d897a..d1278b387 100644 --- a/tests/clientRPC/handlers/agentStatus.test.ts +++ b/tests/clientRPC/handlers/agentStatus.test.ts @@ -110,7 +110,7 @@ describe('agentStatus', () => { logger: logger.getChild('RPCClient'), }); // Doing the test - const result = await rpcClient.methods.agentStatus(null); + const result = await rpcClient.methods.agentStatus({}); expect(result).toStrictEqual({ pid: process.pid, nodeId: nodesUtils.encodeNodeId(keyRing.getNodeId()), diff --git a/tests/clientRPC/handlers/agentUnlock.test.ts b/tests/clientRPC/handlers/agentUnlock.test.ts index 0350fc4d4..1a533a070 100644 --- a/tests/clientRPC/handlers/agentUnlock.test.ts +++ b/tests/clientRPC/handlers/agentUnlock.test.ts @@ -129,23 +129,17 @@ describe('agentUnlock', () => { metadata: { Authorization: clientRPCUtils.encodeAuthFromPassword(password), }, - data: null, }); expect(result).toMatchObject({ metadata: { Authorization: expect.any(String), }, - data: null, - }); - const result2 = await rpcClient.methods.agentUnlock({ - metadata: {}, - data: null, }); + const result2 = await rpcClient.methods.agentUnlock({}); expect(result2).toMatchObject({ metadata: { Authorization: expect.any(String), }, - data: null, }); }); }); From b0cfc7a743dc3d72c71229f06c7668af1c9aa8e0 Mon Sep 17 00:00:00 2001 From: Brian Botha Date: Fri, 10 Feb 2023 16:18:30 +1100 Subject: [PATCH 35/44] feat: client methods refactor [ci skip] --- src/RPC/RPCClient.ts | 241 ++++++++--------- src/RPC/types.ts | 98 ++----- src/clientRPC/handlers/agentStatus.ts | 12 +- src/clientRPC/utils.ts | 12 +- tests/RPC/RPC.test.ts | 22 +- tests/RPC/RPCClient.test.ts | 362 ++++---------------------- 6 files changed, 211 insertions(+), 536 deletions(-) diff --git a/src/RPC/RPCClient.ts b/src/RPC/RPCClient.ts index ee4846b73..50545ec3b 100644 --- a/src/RPC/RPCClient.ts +++ b/src/RPC/RPCClient.ts @@ -3,14 +3,10 @@ import type { JsonRpcRequestMessage, StreamPairCreateCallback, ClientManifest, - MapWithCallers, + MapRawCallers, } from './types'; import type { JSONValue } from 'types'; -import type { - ReadableWritablePair, - ReadableStream, - WritableStream, -} from 'stream/web'; +import type { ReadableWritablePair, WritableStream } from 'stream/web'; import type { JsonRpcRequest, JsonRpcResponse, @@ -21,6 +17,7 @@ import { CreateDestroy, ready } from '@matrixai/async-init/dist/CreateDestroy'; import Logger from '@matrixai/logger'; import * as rpcErrors from './errors'; import * as rpcUtils from './utils'; +import { never } from '../utils'; // eslint-disable-next-line interface RPCClient extends CreateDestroy {} @@ -53,38 +50,35 @@ class RPCClient { {}, { get: (_, method) => { - if (typeof method === 'symbol') throw Error('invalid symbol'); + if (typeof method === 'symbol') throw never(); switch (this.callerTypes[method]) { - case 'DUPLEX': - return () => this.duplexStreamCaller(method); + case 'UNARY': + return (params) => this.unaryCaller(method, params); case 'SERVER': return (params) => this.serverStreamCaller(method, params); case 'CLIENT': - return () => this.clientStreamCaller(method); - case 'UNARY': - return (params) => this.unaryCaller(method, params); + return (f) => this.clientStreamCaller(method, f); + case 'DUPLEX': + return (f) => this.duplexStreamCaller(method, f); case 'RAW': - return (params) => this.rawStreamCaller(method, params); default: return; } }, }, ); - protected withMethodsProxy = new Proxy( + protected rawMethodsProxy = new Proxy( {}, { get: (_, method) => { - if (typeof method === 'symbol') throw Error('invalid symbol'); + if (typeof method === 'symbol') throw never(); switch (this.callerTypes[method]) { case 'DUPLEX': - return (f) => this.withDuplexCaller(method, f); + return () => this.rawDuplexStreamCaller(method); + case 'RAW': + return (params) => this.rawStreamCaller(method, params); case 'SERVER': - return (params, f) => this.withServerCaller(method, params, f); case 'CLIENT': - return (f) => this.withClientCaller(method, f); - case 'RAW': - return (params, f) => this.withRawStreamCaller(method, params, f); case 'UNARY': default: return; @@ -118,31 +112,92 @@ class RPCClient { } @ready(new rpcErrors.ErrorRpcDestroyed()) - public get withMethods(): MapWithCallers { - return this.withMethodsProxy as MapWithCallers; + public get rawMethods(): MapRawCallers { + return this.rawMethodsProxy as MapRawCallers; } + // Convenience methods + @ready(new rpcErrors.ErrorRpcDestroyed()) - public async rawStreamCaller( + public async unaryCaller( method: string, - params: JSONValue, - ): Promise> { - const streamPair = await this.streamPairCreateCallback(); - const tempWriter = streamPair.writable.getWriter(); - const header: JsonRpcRequestMessage = { - jsonrpc: '2.0', - method, - params, - id: null, + parameters: I, + ): Promise { + const callerInterface = await this.rawDuplexStreamCaller(method); + const reader = callerInterface.readable.getReader(); + const writer = callerInterface.writable.getWriter(); + await writer.write(parameters); + const output = await reader.read(); + if (output.done) { + throw new rpcErrors.ErrorRpcRemoteError('Stream ended before response'); + } + await reader.cancel(); + await writer.close(); + return output.value; + } + + @ready(new rpcErrors.ErrorRpcDestroyed()) + public async serverStreamCaller( + method: string, + parameters: I, + ): Promise> { + const callerInterface = await this.rawDuplexStreamCaller(method); + const writer = callerInterface.writable.getWriter(); + await writer.write(parameters); + await writer.close(); + + const outputGen = async function* () { + for await (const value of callerInterface.readable) { + yield value; + } }; - await tempWriter.write(Buffer.from(JSON.stringify(header))); - tempWriter.releaseLock(); - return streamPair; + return outputGen(); + } + + @ready(new rpcErrors.ErrorRpcDestroyed()) + public async clientStreamCaller( + method: string, + f: (output: Promise) => AsyncGenerator, + ): Promise { + const callerInterface = await this.rawClientStreamCaller(method); + const writer = callerInterface.writable.getWriter(); + let running = true; + for await (const value of f(callerInterface.output)) { + if (value === undefined) { + await writer.close(); + running = false; + } + // Write while running otherwise consume until ended + if (running) await writer.write(value); + } + // If ended before finish running then close writer + if (running) await writer.close(); } @ready(new rpcErrors.ErrorRpcDestroyed()) public async duplexStreamCaller( method: string, + f: (output: AsyncGenerator) => AsyncGenerator, + ): Promise { + const callerInterface = await this.rawDuplexStreamCaller(method); + const outputGenerator = async function* () { + for await (const value of callerInterface.readable) { + yield value; + } + }; + const writer = callerInterface.writable.getWriter(); + try { + for await (const value of f(outputGenerator())) { + await writer.write(value); + } + } finally { + await writer.close(); + } + } + + @ready(new rpcErrors.ErrorRpcDestroyed()) + public async rawDuplexStreamCaller( + method: string, ): Promise> { // Creating caller side transforms const outputMessageTransforStream = @@ -179,26 +234,33 @@ class RPCClient { } @ready(new rpcErrors.ErrorRpcDestroyed()) - public async serverStreamCaller( + public async rawStreamCaller( method: string, - parameters: I, - ): Promise> { - const callerInterface = await this.duplexStreamCaller(method); - const writer = callerInterface.writable.getWriter(); - await writer.write(parameters); - await writer.close(); - - return callerInterface.readable; + params: JSONValue, + ): Promise> { + const streamPair = await this.streamPairCreateCallback(); + const tempWriter = streamPair.writable.getWriter(); + const header: JsonRpcRequestMessage = { + jsonrpc: '2.0', + method, + params, + id: null, + }; + await tempWriter.write(Buffer.from(JSON.stringify(header))); + tempWriter.releaseLock(); + return streamPair; } - @ready(new rpcErrors.ErrorRpcDestroyed()) - public async clientStreamCaller( + protected async rawClientStreamCaller< + I extends JSONValue, + O extends JSONValue, + >( method: string, ): Promise<{ output: Promise; writable: WritableStream; }> { - const callerInterface = await this.duplexStreamCaller(method); + const callerInterface = await this.rawDuplexStreamCaller(method); const reader = callerInterface.readable.getReader(); const output = reader.read().then(({ value, done }) => { if (done) { @@ -212,91 +274,6 @@ class RPCClient { }; } - @ready(new rpcErrors.ErrorRpcDestroyed()) - public async unaryCaller( - method: string, - parameters: I, - ): Promise { - const callerInterface = await this.duplexStreamCaller(method); - const reader = callerInterface.readable.getReader(); - const writer = callerInterface.writable.getWriter(); - await writer.write(parameters); - const output = await reader.read(); - if (output.done) { - throw new rpcErrors.ErrorRpcRemoteError('Stream ended before response'); - } - await reader.cancel(); - await writer.close(); - return output.value; - } - - @ready(new rpcErrors.ErrorRpcDestroyed()) - public async withRawStreamCaller( - method: string, - params: JSONValue, - f: (output: AsyncGenerator) => AsyncGenerator, - ): Promise { - const callerInterface = await this.rawStreamCaller(method, params); - const outputGenerator = async function* () { - for await (const value of callerInterface.readable) { - yield value; - } - }; - const writer = callerInterface.writable.getWriter(); - for await (const value of f(outputGenerator())) { - await writer.write(value); - } - await writer.close(); - } - - @ready(new rpcErrors.ErrorRpcDestroyed()) - public async withDuplexCaller( - method: string, - f: (output: AsyncGenerator) => AsyncGenerator, - ): Promise { - const callerInterface = await this.duplexStreamCaller(method); - const outputGenerator = async function* () { - for await (const value of callerInterface.readable) { - yield value; - } - }; - const writer = callerInterface.writable.getWriter(); - for await (const value of f(outputGenerator())) { - await writer.write(value); - } - await writer.close(); - } - - @ready(new rpcErrors.ErrorRpcDestroyed()) - public async withServerCaller( - method: string, - parameters: I, - f: (output: AsyncGenerator) => Promise, - ): Promise { - const callerInterface = await this.serverStreamCaller( - method, - parameters, - ); - const outputGenerator = async function* () { - yield* callerInterface; - }; - await f(outputGenerator()); - } - - @ready(new rpcErrors.ErrorRpcDestroyed()) - public async withClientCaller( - method: string, - f: () => AsyncGenerator, - ): Promise { - const callerInterface = await this.clientStreamCaller(method); - const writer = callerInterface.writable.getWriter(); - for await (const value of f()) { - await writer.write(value); - } - await writer.close(); - return callerInterface.output; - } - protected middleware: Array< MiddlewareFactory< JsonRpcRequest, diff --git a/src/RPC/types.ts b/src/RPC/types.ts index f4e2a03ff..336a52663 100644 --- a/src/RPC/types.ts +++ b/src/RPC/types.ts @@ -1,11 +1,7 @@ import type { JSONValue } from '../types'; import type { ConnectionInfo } from '../network/types'; import type { ContextCancellable } from '../contexts/types'; -import type { - ReadableStream, - ReadableWritablePair, - WritableStream, -} from 'stream/web'; +import type { ReadableStream, ReadableWritablePair } from 'stream/web'; import type { Handler } from './handlers'; import type { Caller, @@ -150,32 +146,38 @@ type MiddlewareFactory = (header?: JsonRpcRequest) => { reverse: ReadableWritablePair; }; -type RawCallerImplementation = ( - params: JSONValue, -) => Promise>; +// Convenience callers -type DuplexCallerImplementation< +type UnaryCallerImplementation< I extends JSONValue = JSONValue, O extends JSONValue = JSONValue, -> = () => Promise>; +> = (parameters: I) => Promise; type ServerCallerImplementation< I extends JSONValue = JSONValue, O extends JSONValue = JSONValue, -> = (parameters: I) => Promise>; +> = (parameters: I) => Promise>; type ClientCallerImplementation< I extends JSONValue = JSONValue, O extends JSONValue = JSONValue, -> = () => Promise<{ - output: Promise; - writable: WritableStream; -}>; +> = (f: (output: Promise) => AsyncGenerator) => Promise; -type UnaryCallerImplementation< +type DuplexCallerImplementation< I extends JSONValue = JSONValue, O extends JSONValue = JSONValue, -> = (parameters: I) => Promise; +> = (f: (output: AsyncGenerator) => AsyncGenerator) => Promise; + +// Raw callers + +type RawDuplexCallerImplementation< + I extends JSONValue = JSONValue, + O extends JSONValue = JSONValue, +> = () => Promise>; + +type RawCallerImplementation = ( + params: JSONValue, +) => Promise>; type ConvertDuplexCaller = T extends DuplexCaller ? DuplexCallerImplementation @@ -201,62 +203,16 @@ type ConvertCaller = T extends DuplexCaller ? ConvertClientCaller : T extends UnaryCaller ? ConvertUnaryCaller - : T extends RawCaller - ? RawCallerImplementation - : never; - -type WithDuplexCallerImplementation< - I extends JSONValue = JSONValue, - O extends JSONValue = JSONValue, -> = (f: (output: AsyncGenerator) => AsyncGenerator) => Promise; - -type WithServerCallerImplementation< - I extends JSONValue = JSONValue, - O extends JSONValue = JSONValue, -> = ( - parameters: I, - f: (output: AsyncGenerator) => Promise, -) => Promise; - -type WithClientCallerImplementation< - I extends JSONValue = JSONValue, - O extends JSONValue = JSONValue, -> = (f: () => AsyncGenerator) => Promise; - -type WithRawCallerImplementation = ( - params: JSONValue, - f: (output: AsyncGenerator) => AsyncGenerator, -) => Promise; - -type ConvertWithDuplexStreamHandler = T extends DuplexCaller< - infer I, - infer O -> - ? WithDuplexCallerImplementation : never; -type ConvertWithServerStreamHandler = T extends ServerCaller< - infer I, - infer O -> - ? WithServerCallerImplementation +type ConvertRawDuplexStreamHandler = T extends DuplexCaller + ? RawDuplexCallerImplementation : never; -type ConvertWithClientStreamHandler = T extends ClientCaller< - infer I, - infer O -> - ? WithClientCallerImplementation - : never; - -type ConvertWithHandler = T extends DuplexCaller - ? ConvertWithDuplexStreamHandler - : T extends ServerCaller - ? ConvertWithServerStreamHandler - : T extends ClientCaller - ? ConvertWithClientStreamHandler +type ConvertRawCaller = T extends DuplexCaller + ? ConvertRawDuplexStreamHandler : T extends RawCaller - ? WithRawCallerImplementation + ? RawCallerImplementation : never; type ServerManifest = Record; @@ -268,8 +224,8 @@ type MapCallers = { [K in keyof T]: ConvertCaller; }; -type MapWithCallers = { - [K in keyof T]: ConvertWithHandler; +type MapRawCallers = { + [K in keyof T]: ConvertRawCaller; }; export type { @@ -294,5 +250,5 @@ export type { ClientManifest, HandlerType, MapCallers, - MapWithCallers, + MapRawCallers, }; diff --git a/src/clientRPC/handlers/agentStatus.ts b/src/clientRPC/handlers/agentStatus.ts index 34fbc24e7..9bfb0a7d7 100644 --- a/src/clientRPC/handlers/agentStatus.ts +++ b/src/clientRPC/handlers/agentStatus.ts @@ -3,11 +3,11 @@ import type KeyRing from '../../keys/KeyRing'; import type CertManager from '../../keys/CertManager'; import type Logger from '@matrixai/logger'; import type { NodeIdEncoded } from '../../ids'; +import type { WithMetadata } from '../types'; import * as nodesUtils from '../../nodes/utils'; import * as keysUtils from '../../keys/utils'; import { UnaryHandler } from '../../RPC/handlers'; import { UnaryCaller } from '../../RPC/callers'; -import { WithMetadata } from "@/clientRPC/types"; type StatusResult = { pid: number; @@ -15,7 +15,10 @@ type StatusResult = { publicJwk: string; }; -const agentStatusCaller = new UnaryCaller>(); +const agentStatusCaller = new UnaryCaller< + WithMetadata, + WithMetadata +>(); class AgentStatusHandler extends UnaryHandler< { @@ -26,7 +29,10 @@ class AgentStatusHandler extends UnaryHandler< WithMetadata, WithMetadata > { - public handle: UnaryHandlerImplementation> = async () => { + public handle: UnaryHandlerImplementation< + WithMetadata, + WithMetadata + > = async () => { return { pid: process.pid, nodeId: nodesUtils.encodeNodeId(this.container.keyRing.getNodeId()), diff --git a/src/clientRPC/utils.ts b/src/clientRPC/utils.ts index 1e8d189dc..ccd5f8b74 100644 --- a/src/clientRPC/utils.ts +++ b/src/clientRPC/utils.ts @@ -114,8 +114,10 @@ function authenticationMiddlewareServer( transform: (chunk, controller) => { // Add the outgoing metadata to the next message. if (outgoingToken != null && 'result' in chunk) { - if (chunk.result.metadata == null) chunk.result.metadata = { - Authorization: '', + if (chunk.result.metadata == null) { + chunk.result.metadata = { + Authorization: '', + }; } chunk.result.metadata.Authorization = outgoingToken; outgoingToken = null; @@ -148,8 +150,10 @@ function authenticationMiddlewareClient( if (chunk.params.metadata?.Authorization == null) { const token = await session.readToken(); if (token != null) { - if (chunk.params.metadata == null) chunk.params.metadata = { - Authorization: '', + if (chunk.params.metadata == null) { + chunk.params.metadata = { + Authorization: '', + }; } chunk.params.metadata.Authorization = `Bearer ${token}`; } diff --git a/tests/RPC/RPC.test.ts b/tests/RPC/RPC.test.ts index 0dc9e8c37..2793e24de 100644 --- a/tests/RPC/RPC.test.ts +++ b/tests/RPC/RPC.test.ts @@ -65,7 +65,7 @@ describe('RPC', () => { logger, }); - const callerInterface = await rpcClient.methods.testMethod({ + const callerInterface = await rpcClient.rawMethods.testMethod({ hello: 'world', }); const writer = callerInterface.writable.getWriter(); @@ -118,7 +118,7 @@ describe('RPC', () => { logger, }); - const callerInterface = await rpcClient.methods.testMethod(); + const callerInterface = await rpcClient.rawMethods.testMethod(); const writer = callerInterface.writable.getWriter(); const reader = callerInterface.readable.getReader(); for (const value of values) { @@ -214,15 +214,17 @@ describe('RPC', () => { logger, }); - const callerInterface = await rpcClient.methods.testMethod(); - const writer = callerInterface.writable.getWriter(); - for (const value of values) { - await writer.write(value); - } - await writer.close(); + await rpcClient.methods.testMethod(async function* (output) { + for (const value of values) { + yield value; + } + // Ending writes + yield undefined; + // Checking output + const expectedResult = values.reduce((p, c) => p + c); + await expect(output).resolves.toEqual(expectedResult); + }); - const expectedResult = values.reduce((p, c) => p + c); - await expect(callerInterface.output).resolves.toEqual(expectedResult); await rpcServer.destroy(); await rpcClient.destroy(); }, diff --git a/tests/RPC/RPCClient.test.ts b/tests/RPC/RPCClient.test.ts index 3af073353..ae2fda61e 100644 --- a/tests/RPC/RPCClient.test.ts +++ b/tests/RPC/RPCClient.test.ts @@ -33,7 +33,7 @@ describe(`${RPCClient.name}`, () => { .noShrink(); testProp( - 'raw duplex caller', + 'raw caller', [ rpcTestUtils.safeJsonValueArb, rpcTestUtils.rawDataArb, @@ -97,21 +97,13 @@ describe(`${RPCClient.name}`, () => { streamPairCreateCallback: async () => streamPair, logger, }); - const callerInterface = await rpcClient.duplexStreamCaller< - JSONValue, - JSONValue - >(methodName); - const reader = callerInterface.readable.getReader(); - const writer = callerInterface.writable.getWriter(); - while (true) { - const { value, done } = await reader.read(); - if (done) { - // We have to end the writer otherwise the stream never closes - await writer.close(); - break; - } - await writer.write(value); - } + await rpcClient.duplexStreamCaller( + methodName, + async function* (output) { + yield* output; + }, + ); + const expectedMessages: Array = messages.map((v) => { const request: JsonRpcRequestMessage = { jsonrpc: '2.0', @@ -182,16 +174,16 @@ describe(`${RPCClient.name}`, () => { streamPairCreateCallback: async () => streamPair, logger, }); - const callerInterface = await rpcClient.clientStreamCaller< - JSONValue, - JSONValue - >(methodName); - const writer = callerInterface.writable.getWriter(); - for (const param of params) { - await writer.write(param as JSONValue); - } - await writer.close(); - expect(await callerInterface.output).toStrictEqual(message.result); + await rpcClient.clientStreamCaller( + methodName, + async function* (output) { + for (const param of params) { + yield param; + } + yield undefined; + expect(await output).toStrictEqual(message.result); + }, + ); const expectedOutput = params.map((v) => JSON.stringify({ method: methodName, @@ -259,77 +251,21 @@ describe(`${RPCClient.name}`, () => { streamPairCreateCallback: async () => streamPair, logger, }); - const callerInterface = await rpcClient.duplexStreamCaller< - JSONValue, - JSONValue - >(methodName); - const consumeToError = async () => { - for await (const _ of callerInterface.readable) { - // No touch, just consume - } - }; - await expect(consumeToError()).rejects.toThrow( - rpcErrors.ErrorRpcRemoteError, - ); - await callerInterface.writable.close(); - await outputResult; - await rpcClient.destroy(); - }, - ); - testProp( - 'withRawStreamCaller', - [ - rpcTestUtils.safeJsonValueArb, - rpcTestUtils.rawDataArb, - rpcTestUtils.rawDataArb, - ], - async (headerParams, inputData, outputData) => { - const [inputResult, inputWritableStream] = - rpcTestUtils.streamToArray(); - const streamPair: ReadableWritablePair = { - readable: new ReadableStream({ - start: (controller) => { - for (const datum of outputData) { - controller.enqueue(datum); - } - controller.close(); - }, - }), - writable: inputWritableStream, - }; - const rpcClient = await RPCClient.createRPCClient({ - manifest: {}, - streamPairCreateCallback: async () => streamPair, - logger, - }); - const outputResult: Array = []; - await rpcClient.withRawStreamCaller( + const callProm = rpcClient.duplexStreamCaller( methodName, - headerParams, async function* (output) { - for await (const outputValue of output) { - outputResult.push(outputValue); - } - for (const inputDatum of inputData) { - yield inputDatum; + for await (const _ of output) { + // No touch, just consume } }, ); - const expectedHeader: JsonRpcRequest = { - jsonrpc: '2.0', - method: methodName, - params: headerParams, - id: null, - }; - expect(await inputResult).toStrictEqual([ - Buffer.from(JSON.stringify(expectedHeader)), - ...inputData, - ]); - expect(outputResult).toStrictEqual(outputData); + await expect(callProm).rejects.toThrow(rpcErrors.ErrorRpcRemoteError); + await outputResult; + await rpcClient.destroy(); }, ); testProp( - 'withDuplexCaller', + 'rawDuplexStreamCaller', [fc.array(rpcTestUtils.jsonRpcResponseResultArb(), { minLength: 1 })], async (messages) => { const inputStream = rpcTestUtils.messagesToReadableStream(messages); @@ -345,12 +281,13 @@ describe(`${RPCClient.name}`, () => { logger, }); let count = 0; - await rpcClient.withDuplexCaller(methodName, async function* (output) { - for await (const value of output) { - count += 1; - yield value; - } - }); + const callerInterface = await rpcClient.rawDuplexStreamCaller(methodName); + const writer = callerInterface.writable.getWriter(); + for await (const val of callerInterface.readable) { + count += 1; + await writer.write(val); + } + await writer.close(); const result = await outputResult; // We're just checking that it's consuming the messages as expected expect(result.length).toEqual(messages.length); @@ -358,84 +295,6 @@ describe(`${RPCClient.name}`, () => { await rpcClient.destroy(); }, ); - testProp( - 'withServerCaller', - [ - fc.array(rpcTestUtils.jsonRpcResponseResultArb(), { minLength: 1 }), - rpcTestUtils.safeJsonValueArb, - ], - async (messages, params) => { - const inputStream = rpcTestUtils.messagesToReadableStream(messages); - const [outputResult, outputStream] = - rpcTestUtils.streamToArray(); - const streamPair: ReadableWritablePair = { - readable: inputStream, - writable: outputStream, - }; - const rpcClient = await RPCClient.createRPCClient({ - manifest: {}, - streamPairCreateCallback: async () => streamPair, - logger, - }); - let count = 0; - await rpcClient.withServerCaller(methodName, params, async (output) => { - for await (const _ of output) count += 1; - }); - const result = await outputResult; - expect(count).toEqual(messages.length); - expect(result.toString()).toStrictEqual( - JSON.stringify({ - method: methodName, - jsonrpc: '2.0', - id: null, - params: params, - }), - ); - await rpcClient.destroy(); - }, - ); - testProp( - 'withClientCaller', - [ - rpcTestUtils.jsonRpcResponseResultArb(), - fc.array(rpcTestUtils.safeJsonValueArb, { minLength: 2 }).noShrink(), - ], - async (message, inputMessages) => { - const inputStream = rpcTestUtils.messagesToReadableStream([message]); - const [outputResult, outputStream] = - rpcTestUtils.streamToArray(); - const streamPair: ReadableWritablePair = { - readable: inputStream, - writable: outputStream, - }; - const rpcClient = await RPCClient.createRPCClient({ - manifest: {}, - streamPairCreateCallback: async () => streamPair, - logger, - }); - const result = await rpcClient.withClientCaller( - methodName, - async function* () { - for (const inputMessage of inputMessages) { - yield inputMessage; - } - }, - ); - const expectedResult = inputMessages.map((v) => { - return JSON.stringify({ - method: methodName, - jsonrpc: '2.0', - id: null, - params: v, - }); - }); - expect((await outputResult).map((v) => v.toString())).toStrictEqual( - expectedResult, - ); - expect(result).toStrictEqual(message.result); - await rpcClient.destroy(); - }, - ); testProp( 'generic duplex caller with forward Middleware', [specificMessageArb], @@ -466,7 +325,7 @@ describe(`${RPCClient.name}`, () => { reverse: new TransformStream(), }; }); - const callerInterface = await rpcClient.duplexStreamCaller< + const callerInterface = await rpcClient.rawDuplexStreamCaller< JSONValue, JSONValue >(methodName); @@ -530,7 +389,7 @@ describe(`${RPCClient.name}`, () => { }), }; }); - const callerInterface = await rpcClient.duplexStreamCaller< + const callerInterface = await rpcClient.rawDuplexStreamCaller< JSONValue, JSONValue >(methodName); @@ -551,7 +410,7 @@ describe(`${RPCClient.name}`, () => { }, ); testProp( - 'manifest duplex call', + 'manifest raw duplex call', [ fc.array(rpcTestUtils.jsonRpcResponseResultArb(fc.string()), { minLength: 5, @@ -571,7 +430,7 @@ describe(`${RPCClient.name}`, () => { streamPairCreateCallback: async () => streamPair, logger, }); - const callerInterface = await rpcClient.methods.duplex(); + const callerInterface = await rpcClient.rawMethods.duplex(); const reader = callerInterface.readable.getReader(); const writer = callerInterface.writable.getWriter(); while (true) { @@ -658,13 +517,13 @@ describe(`${RPCClient.name}`, () => { streamPairCreateCallback: async () => streamPair, logger, }); - const callerInterface = await rpcClient.methods.client(); - const writer = callerInterface.writable.getWriter(); - for (const param of params) { - await writer.write(param); - } - await writer.close(); - expect(await callerInterface.output).toStrictEqual(message.result); + await rpcClient.methods.client(async function* (output) { + for (const param of params) { + yield param; + } + yield undefined; + expect(await output).toStrictEqual(message.result); + }); const expectedOutput = params.map((v) => JSON.stringify({ method: 'client', @@ -739,7 +598,7 @@ describe(`${RPCClient.name}`, () => { streamPairCreateCallback: async () => streamPair, logger, }); - const callerInterface = await rpcClient.methods.raw(headerParams); + const callerInterface = await rpcClient.rawMethods.raw(headerParams); await callerInterface.readable.pipeTo(outputWritableStream); const writer = callerInterface.writable.getWriter(); for (const inputDatum of inputData) { @@ -761,7 +620,7 @@ describe(`${RPCClient.name}`, () => { }, ); testProp( - 'manifest withDuplex caller', + 'manifest duplex caller', [ fc.array(rpcTestUtils.jsonRpcResponseResultArb(fc.string()), { minLength: 1, @@ -783,7 +642,7 @@ describe(`${RPCClient.name}`, () => { logger, }); let count = 0; - await rpcClient.withMethods.duplex(async function* (output) { + await rpcClient.methods.duplex(async function* (output) { for await (const value of output) { count += 1; yield value; @@ -796,135 +655,6 @@ describe(`${RPCClient.name}`, () => { await rpcClient.destroy(); }, ); - testProp( - 'manifest withServer caller', - [ - fc.array(rpcTestUtils.jsonRpcResponseResultArb(), { minLength: 1 }), - fc.string(), - ], - async (messages, params) => { - const inputStream = rpcTestUtils.messagesToReadableStream(messages); - const [outputResult, outputStream] = - rpcTestUtils.streamToArray(); - const streamPair: ReadableWritablePair = { - readable: inputStream, - writable: outputStream, - }; - const rpcClient = await RPCClient.createRPCClient({ - manifest: { - server: new ServerCaller(), - }, - streamPairCreateCallback: async () => streamPair, - logger, - }); - let count = 0; - await rpcClient.withMethods.server(params, async (output) => { - for await (const _ of output) count += 1; - }); - const result = await outputResult; - expect(count).toEqual(messages.length); - expect(result.toString()).toStrictEqual( - JSON.stringify({ - method: 'server', - jsonrpc: '2.0', - id: null, - params: params, - }), - ); - await rpcClient.destroy(); - }, - ); - testProp( - 'manifest withClient caller', - [ - rpcTestUtils.jsonRpcResponseResultArb(), - fc.array(fc.string(), { minLength: 2 }).noShrink(), - ], - async (message, inputMessages) => { - const inputStream = rpcTestUtils.messagesToReadableStream([message]); - const [outputResult, outputStream] = - rpcTestUtils.streamToArray(); - const streamPair: ReadableWritablePair = { - readable: inputStream, - writable: outputStream, - }; - const rpcClient = await RPCClient.createRPCClient({ - manifest: { - client: new ClientCaller(), - }, - streamPairCreateCallback: async () => streamPair, - logger, - }); - const result = await rpcClient.withMethods.client(async function* () { - for (const inputMessage of inputMessages) { - yield inputMessage; - } - }); - const expectedResult = inputMessages.map((v) => { - return JSON.stringify({ - method: 'client', - jsonrpc: '2.0', - id: null, - params: v, - }); - }); - expect((await outputResult).map((v) => v.toString())).toStrictEqual( - expectedResult, - ); - expect(result).toStrictEqual(message.result); - await rpcClient.destroy(); - }, - ); - testProp( - 'manifest withRaw caller', - [ - rpcTestUtils.safeJsonValueArb, - rpcTestUtils.rawDataArb, - rpcTestUtils.rawDataArb, - ], - async (headerParams, inputData, outputData) => { - const [inputResult, inputWritableStream] = - rpcTestUtils.streamToArray(); - const streamPair: ReadableWritablePair = { - readable: new ReadableStream({ - start: (controller) => { - for (const datum of outputData) { - controller.enqueue(datum); - } - controller.close(); - }, - }), - writable: inputWritableStream, - }; - const rpcClient = await RPCClient.createRPCClient({ - manifest: { - raw: new RawCaller(), - }, - streamPairCreateCallback: async () => streamPair, - logger, - }); - const outputResult: Array = []; - await rpcClient.withMethods.raw(headerParams, async function* (output) { - for await (const outputValue of output) { - outputResult.push(outputValue); - } - for (const inputDatum of inputData) { - yield inputDatum; - } - }); - const expectedHeader: JsonRpcRequest = { - jsonrpc: '2.0', - method: 'raw', - params: headerParams, - id: null, - }; - expect(await inputResult).toStrictEqual([ - Buffer.from(JSON.stringify(expectedHeader)), - ...inputData, - ]); - expect(outputResult).toStrictEqual(outputData); - }, - ); test('manifest without handler errors', async () => { const rpcClient = await RPCClient.createRPCClient({ manifest: {}, From 534038fc07b7de177d30543265e07e69d85f4acd Mon Sep 17 00:00:00 2001 From: Brian Botha Date: Fri, 10 Feb 2023 17:37:43 +1100 Subject: [PATCH 36/44] feat: client static middleware registration Related #501 Related #502 [ci skip] --- src/RPC/RPCClient.ts | 83 +++++++++----------- src/RPC/RPCServer.ts | 2 +- src/RPC/utils.ts | 49 +++++++++++- tests/RPC/RPCClient.test.ts | 53 +++++++------ tests/RPC/RPCServer.test.ts | 6 +- tests/clientRPC/handlers/agentUnlock.test.ts | 8 +- 6 files changed, 117 insertions(+), 84 deletions(-) diff --git a/src/RPC/RPCClient.ts b/src/RPC/RPCClient.ts index 50545ec3b..e75d1b54b 100644 --- a/src/RPC/RPCClient.ts +++ b/src/RPC/RPCClient.ts @@ -17,6 +17,10 @@ import { CreateDestroy, ready } from '@matrixai/async-init/dist/CreateDestroy'; import Logger from '@matrixai/logger'; import * as rpcErrors from './errors'; import * as rpcUtils from './utils'; +import { + clientInputTransformStream, + clientOutputTransformStream, +} from './utils'; import { never } from '../utils'; // eslint-disable-next-line @@ -26,16 +30,24 @@ class RPCClient { static async createRPCClient({ manifest, streamPairCreateCallback, + middleware = rpcUtils.defaultClientMiddlewareWrapper(), logger = new Logger(this.name), }: { manifest: M; streamPairCreateCallback: StreamPairCreateCallback; - logger: Logger; + middleware?: MiddlewareFactory< + Uint8Array, + JsonRpcRequest, + JsonRpcResponse, + Uint8Array + >; + logger?: Logger; }) { logger.info(`Creating ${this.name}`); const rpcClient = new this({ manifest, streamPairCreateCallback, + middleware, logger, }); logger.info(`Created ${this.name}`); @@ -44,6 +56,12 @@ class RPCClient { protected logger: Logger; protected streamPairCreateCallback: StreamPairCreateCallback; + protected middleware: MiddlewareFactory< + Uint8Array, + JsonRpcRequest, + JsonRpcResponse, + Uint8Array + >; protected callerTypes: Record; // Method proxies public readonly methodsProxy = new Proxy( @@ -90,14 +108,22 @@ class RPCClient { public constructor({ manifest, streamPairCreateCallback, + middleware, logger, }: { manifest: M; streamPairCreateCallback: StreamPairCreateCallback; + middleware: MiddlewareFactory< + Uint8Array, + JsonRpcRequest, + JsonRpcResponse, + Uint8Array + >; logger: Logger; }) { this.callerTypes = rpcUtils.getHandlerTypes(manifest); this.streamPairCreateCallback = streamPairCreateCallback; + this.middleware = middleware; this.logger = logger; } @@ -199,36 +225,23 @@ class RPCClient { public async rawDuplexStreamCaller( method: string, ): Promise> { - // Creating caller side transforms - const outputMessageTransforStream = - rpcUtils.clientOutputTransformStream(); - const inputMessageTransformStream = - rpcUtils.clientInputTransformStream(method); - let reverseStream = outputMessageTransforStream.writable; - let forwardStream = inputMessageTransformStream.readable; - // Setting up middleware chains - for (const middlewareFactory of this.middleware) { - const middleware = middlewareFactory(); - forwardStream = forwardStream.pipeThrough(middleware.forward); - void middleware.reverse.readable.pipeTo(reverseStream).catch(() => {}); - reverseStream = middleware.reverse.writable; - } + const outputMessageTransformStream = clientOutputTransformStream(); + const inputMessageTransformStream = clientInputTransformStream(method); + const middleware = this.middleware(); // Hooking up agnostic stream side const streamPair = await this.streamPairCreateCallback(); void streamPair.readable - .pipeThrough( - rpcUtils.binaryToJsonMessageStream(rpcUtils.parseJsonRpcResponse), - ) - .pipeTo(reverseStream) + .pipeThrough(middleware.reverse) + .pipeTo(outputMessageTransformStream.writable) .catch(() => {}); - void forwardStream - .pipeThrough(rpcUtils.jsonMessageToBinaryStream()) + void inputMessageTransformStream.readable + .pipeThrough(middleware.forward) .pipeTo(streamPair.writable) .catch(() => {}); // Returning interface return { - readable: outputMessageTransforStream.readable, + readable: outputMessageTransformStream.readable, writable: inputMessageTransformStream.writable, }; } @@ -273,32 +286,6 @@ class RPCClient { writable: callerInterface.writable, }; } - - protected middleware: Array< - MiddlewareFactory< - JsonRpcRequest, - JsonRpcRequest, - JsonRpcResponse, - JsonRpcResponse - > - > = []; - - @ready(new rpcErrors.ErrorRpcDestroyed()) - public registerMiddleware( - middlewareFactory: MiddlewareFactory< - JsonRpcRequest, - JsonRpcRequest, - JsonRpcResponse, - JsonRpcResponse - >, - ) { - this.middleware.push(middlewareFactory); - } - - @ready(new rpcErrors.ErrorRpcDestroyed()) - public clearMiddleware() { - this.middleware = []; - } } export default RPCClient; diff --git a/src/RPC/RPCServer.ts b/src/RPC/RPCServer.ts index 78713736f..a25813ae7 100644 --- a/src/RPC/RPCServer.ts +++ b/src/RPC/RPCServer.ts @@ -37,7 +37,7 @@ interface RPCServer extends CreateDestroy {} class RPCServer { static async createRPCServer({ manifest, - middleware = rpcUtils.defaultMiddlewareWrapper(), + middleware = rpcUtils.defaultServerMiddlewareWrapper(), logger = new Logger(this.name), }: { manifest: ServerManifest; diff --git a/src/RPC/utils.ts b/src/RPC/utils.ts index dad91d132..4a069e3f1 100644 --- a/src/RPC/utils.ts +++ b/src/RPC/utils.ts @@ -586,7 +586,7 @@ const defaultMiddleware: MiddlewareFactory< }; }; -const defaultMiddlewareWrapper = ( +const defaultServerMiddlewareWrapper = ( middleware: MiddlewareFactory< JsonRpcRequest, JsonRpcRequest, @@ -627,6 +627,50 @@ const defaultMiddlewareWrapper = ( }; }; +const defaultClientMiddlewareWrapper = ( + middleware: MiddlewareFactory< + JsonRpcRequest, + JsonRpcRequest, + JsonRpcResponse, + JsonRpcResponse + > = defaultMiddleware, +): MiddlewareFactory< + Uint8Array, + JsonRpcRequest, + JsonRpcResponse, + Uint8Array +> => { + return () => { + const outputTransformStream = binaryToJsonMessageStream( + parseJsonRpcResponse, + undefined, + ); + const inputTransformStream = new TransformStream< + JsonRpcRequest, + JsonRpcRequest + >(); + + const middleMiddleware = middleware(); + const forwardReadable = inputTransformStream.readable + .pipeThrough(middleMiddleware.forward) // Usual middleware here + .pipeThrough(jsonMessageToBinaryStream()); + const reverseReadable = outputTransformStream.readable.pipeThrough( + middleMiddleware.reverse, + ); // Usual middleware here + + return { + forward: { + readable: forwardReadable, + writable: inputTransformStream.writable, + }, + reverse: { + readable: reverseReadable, + writable: outputTransformStream.writable, + }, + }; + }; +}; + export { binaryToJsonMessageStream, jsonMessageToBinaryStream, @@ -648,5 +692,6 @@ export { extractFirstMessageTransform, getHandlerTypes, defaultMiddleware, - defaultMiddlewareWrapper, + defaultServerMiddlewareWrapper, + defaultClientMiddlewareWrapper, }; diff --git a/tests/RPC/RPCClient.test.ts b/tests/RPC/RPCClient.test.ts index ae2fda61e..d675d407e 100644 --- a/tests/RPC/RPCClient.test.ts +++ b/tests/RPC/RPCClient.test.ts @@ -18,6 +18,7 @@ import { ServerCaller, UnaryCaller, } from '@/RPC/callers'; +import * as rpcUtils from '@/RPC/utils'; import * as rpcTestUtils from './utils'; describe(`${RPCClient.name}`, () => { @@ -309,22 +310,22 @@ describe(`${RPCClient.name}`, () => { const rpcClient = await RPCClient.createRPCClient({ manifest: {}, streamPairCreateCallback: async () => streamPair, + middleware: rpcUtils.defaultClientMiddlewareWrapper(() => { + return { + forward: new TransformStream({ + transform: (chunk, controller) => { + controller.enqueue({ + ...chunk, + params: 'one', + }); + }, + }), + reverse: new TransformStream(), + }; + }), logger, }); - rpcClient.registerMiddleware(() => { - return { - forward: new TransformStream({ - transform: (chunk, controller) => { - controller.enqueue({ - ...chunk, - params: 'one', - }); - }, - }), - reverse: new TransformStream(), - }; - }); const callerInterface = await rpcClient.rawDuplexStreamCaller< JSONValue, JSONValue @@ -373,22 +374,22 @@ describe(`${RPCClient.name}`, () => { const rpcClient = await RPCClient.createRPCClient({ manifest: {}, streamPairCreateCallback: async () => streamPair, + middleware: rpcUtils.defaultClientMiddlewareWrapper(() => { + return { + forward: new TransformStream(), + reverse: new TransformStream({ + transform: (chunk, controller) => { + controller.enqueue({ + ...chunk, + result: 'one', + }); + }, + }), + }; + }), logger, }); - rpcClient.registerMiddleware(() => { - return { - forward: new TransformStream(), - reverse: new TransformStream({ - transform: (chunk, controller) => { - controller.enqueue({ - ...chunk, - result: 'one', - }); - }, - }), - }; - }); const callerInterface = await rpcClient.rawDuplexStreamCaller< JSONValue, JSONValue diff --git a/tests/RPC/RPCServer.test.ts b/tests/RPC/RPCServer.test.ts index a613e2ec5..02249df46 100644 --- a/tests/RPC/RPCServer.test.ts +++ b/tests/RPC/RPCServer.test.ts @@ -456,7 +456,7 @@ describe(`${RPCServer.name}`, () => { } }; } - const middleware = rpcUtils.defaultMiddlewareWrapper(() => { + const middleware = rpcUtils.defaultServerMiddlewareWrapper(() => { return { forward: new TransformStream({ transform: (chunk, controller) => { @@ -501,7 +501,7 @@ describe(`${RPCServer.name}`, () => { } }; } - const middleware = rpcUtils.defaultMiddlewareWrapper(() => { + const middleware = rpcUtils.defaultServerMiddlewareWrapper(() => { return { forward: new TransformStream(), reverse: new TransformStream({ @@ -549,7 +549,7 @@ describe(`${RPCServer.name}`, () => { } }; } - const middleware = rpcUtils.defaultMiddlewareWrapper(() => { + const middleware = rpcUtils.defaultServerMiddlewareWrapper(() => { let first = true; let reverseController: TransformStreamDefaultController; return { diff --git a/tests/clientRPC/handlers/agentUnlock.test.ts b/tests/clientRPC/handlers/agentUnlock.test.ts index 1a533a070..45939e327 100644 --- a/tests/clientRPC/handlers/agentUnlock.test.ts +++ b/tests/clientRPC/handlers/agentUnlock.test.ts @@ -97,7 +97,7 @@ describe('agentUnlock', () => { manifest: { agentUnlock: new AgentUnlockHandler({ logger }), }, - middleware: rpcUtils.defaultMiddlewareWrapper( + middleware: rpcUtils.defaultServerMiddlewareWrapper( clientRPCUtils.authenticationMiddlewareServer(sessionManager, keyRing), ), logger, @@ -118,11 +118,11 @@ describe('agentUnlock', () => { logger.getChild('client'), ); }, + middleware: rpcUtils.defaultClientMiddlewareWrapper( + clientRPCUtils.authenticationMiddlewareClient(session), + ), logger, }); - rpcClient.registerMiddleware( - clientRPCUtils.authenticationMiddlewareClient(session), - ); // Doing the test const result = await rpcClient.methods.agentUnlock({ From 8210075b2af8645f6bc115d7a08f3c187f9481cc Mon Sep 17 00:00:00 2001 From: Brian Botha Date: Fri, 10 Feb 2023 18:00:09 +1100 Subject: [PATCH 37/44] fix: converting generators to async iterables Related #500 Related #501 [ci skip] --- src/RPC/RPCClient.ts | 6 +++--- src/RPC/RPCServer.ts | 2 +- src/RPC/types.ts | 12 ++++++------ 3 files changed, 10 insertions(+), 10 deletions(-) diff --git a/src/RPC/RPCClient.ts b/src/RPC/RPCClient.ts index e75d1b54b..242ecf8be 100644 --- a/src/RPC/RPCClient.ts +++ b/src/RPC/RPCClient.ts @@ -166,7 +166,7 @@ class RPCClient { public async serverStreamCaller( method: string, parameters: I, - ): Promise> { + ): Promise> { const callerInterface = await this.rawDuplexStreamCaller(method); const writer = callerInterface.writable.getWriter(); await writer.write(parameters); @@ -183,7 +183,7 @@ class RPCClient { @ready(new rpcErrors.ErrorRpcDestroyed()) public async clientStreamCaller( method: string, - f: (output: Promise) => AsyncGenerator, + f: (output: Promise) => AsyncIterable, ): Promise { const callerInterface = await this.rawClientStreamCaller(method); const writer = callerInterface.writable.getWriter(); @@ -203,7 +203,7 @@ class RPCClient { @ready(new rpcErrors.ErrorRpcDestroyed()) public async duplexStreamCaller( method: string, - f: (output: AsyncGenerator) => AsyncGenerator, + f: (output: AsyncIterable) => AsyncIterable, ): Promise { const callerInterface = await this.rawDuplexStreamCaller(method); const outputGenerator = async function* () { diff --git a/src/RPC/RPCServer.ts b/src/RPC/RPCServer.ts index a25813ae7..d51d5bf19 100644 --- a/src/RPC/RPCServer.ts +++ b/src/RPC/RPCServer.ts @@ -154,7 +154,7 @@ class RPCServer { const events = this.events; const outputGen = async function* (): AsyncGenerator { if (ctx.signal.aborted) throw ctx.signal.reason; - const dataGen = async function* () { + const dataGen = async function* (): AsyncIterable { for await (const data of forwardStream) { yield data.params as I; } diff --git a/src/RPC/types.ts b/src/RPC/types.ts index 336a52663..b169f4c74 100644 --- a/src/RPC/types.ts +++ b/src/RPC/types.ts @@ -121,15 +121,15 @@ type RawHandlerImplementation = HandlerImplementation< type DuplexHandlerImplementation< I extends JSONValue = JSONValue, O extends JSONValue = JSONValue, -> = HandlerImplementation, AsyncGenerator>; +> = HandlerImplementation, AsyncIterable>; type ServerHandlerImplementation< I extends JSONValue = JSONValue, O extends JSONValue = JSONValue, -> = HandlerImplementation>; +> = HandlerImplementation>; type ClientHandlerImplementation< I extends JSONValue = JSONValue, O extends JSONValue = JSONValue, -> = HandlerImplementation, Promise>; +> = HandlerImplementation, Promise>; type UnaryHandlerImplementation< I extends JSONValue = JSONValue, O extends JSONValue = JSONValue, @@ -156,17 +156,17 @@ type UnaryCallerImplementation< type ServerCallerImplementation< I extends JSONValue = JSONValue, O extends JSONValue = JSONValue, -> = (parameters: I) => Promise>; +> = (parameters: I) => Promise>; type ClientCallerImplementation< I extends JSONValue = JSONValue, O extends JSONValue = JSONValue, -> = (f: (output: Promise) => AsyncGenerator) => Promise; +> = (f: (output: Promise) => AsyncIterable) => Promise; type DuplexCallerImplementation< I extends JSONValue = JSONValue, O extends JSONValue = JSONValue, -> = (f: (output: AsyncGenerator) => AsyncGenerator) => Promise; +> = (f: (output: AsyncIterable) => AsyncIterable) => Promise; // Raw callers From 427adaccf792683f967d567e254af7f16c63ed9f Mon Sep 17 00:00:00 2001 From: Brian Botha Date: Fri, 10 Feb 2023 19:02:10 +1100 Subject: [PATCH 38/44] feat: handler class using abstract function Related #500 [ci skip] --- src/RPC/RPCServer.ts | 27 +++++-- src/RPC/handlers.ts | 43 +++++++--- src/clientRPC/handlers/agentStatus.ts | 8 +- src/clientRPC/handlers/agentUnlock.ts | 18 ++--- tests/RPC/RPC.test.ts | 52 ++++++------ tests/RPC/RPCServer.test.ts | 111 +++++++++++++------------- tests/clientRPC/websocket.test.ts | 11 +-- 7 files changed, 146 insertions(+), 124 deletions(-) diff --git a/src/RPC/RPCServer.ts b/src/RPC/RPCServer.ts index d51d5bf19..e05dcbe6b 100644 --- a/src/RPC/RPCServer.ts +++ b/src/RPC/RPCServer.ts @@ -87,27 +87,42 @@ class RPCServer { }) { for (const [key, manifestItem] of Object.entries(manifest)) { if (manifestItem instanceof RawHandler) { - this.registerRawStreamHandler(key, manifestItem.handle); + this.registerRawStreamHandler( + key, + manifestItem.handle.bind(manifestItem), + ); continue; } if (manifestItem instanceof DuplexHandler) { - this.registerDuplexStreamHandler(key, manifestItem.handle); + this.registerDuplexStreamHandler( + key, + manifestItem.handle.bind(manifestItem), + ); continue; } if (manifestItem instanceof ServerHandler) { - this.registerServerStreamHandler(key, manifestItem.handle); + this.registerServerStreamHandler( + key, + manifestItem.handle.bind(manifestItem), + ); continue; } if (manifestItem instanceof ClientHandler) { - this.registerClientStreamHandler(key, manifestItem.handle); + this.registerClientStreamHandler( + key, + manifestItem.handle.bind(manifestItem), + ); continue; } if (manifestItem instanceof ClientHandler) { - this.registerClientStreamHandler(key, manifestItem.handle); + this.registerClientStreamHandler( + key, + manifestItem.handle.bind(manifestItem), + ); continue; } if (manifestItem instanceof UnaryHandler) { - this.registerUnaryHandler(key, manifestItem.handle); + this.registerUnaryHandler(key, manifestItem.handle.bind(manifestItem)); continue; } never(); diff --git a/src/RPC/handlers.ts b/src/RPC/handlers.ts index 86d1ba149..c738c74e8 100644 --- a/src/RPC/handlers.ts +++ b/src/RPC/handlers.ts @@ -1,12 +1,9 @@ import type { JSONValue } from 'types'; -import type { - ClientHandlerImplementation, - DuplexHandlerImplementation, - RawHandlerImplementation, - ServerHandlerImplementation, - UnaryHandlerImplementation, - ContainerType, -} from 'RPC/types'; +import type { ContainerType } from 'RPC/types'; +import type { ReadableStream } from 'stream/web'; +import type { JsonRpcRequest } from 'RPC/types'; +import type { ConnectionInfo } from '../network/types'; +import type { ContextCancellable } from '../contexts/types'; abstract class Handler< Container extends ContainerType = ContainerType, @@ -22,7 +19,11 @@ abstract class Handler< abstract class RawHandler< Container extends ContainerType = ContainerType, > extends Handler { - abstract handle: RawHandlerImplementation; + abstract handle( + input: [ReadableStream, JsonRpcRequest], + connectionInfo: ConnectionInfo, + ctx: ContextCancellable, + ): ReadableStream; } abstract class DuplexHandler< @@ -30,7 +31,11 @@ abstract class DuplexHandler< Input extends JSONValue = JSONValue, Output extends JSONValue = JSONValue, > extends Handler { - abstract handle: DuplexHandlerImplementation; + abstract handle( + input: AsyncIterable, + connectionInfo: ConnectionInfo, + ctx: ContextCancellable, + ): AsyncIterable; } abstract class ServerHandler< @@ -38,7 +43,11 @@ abstract class ServerHandler< Input extends JSONValue = JSONValue, Output extends JSONValue = JSONValue, > extends Handler { - abstract handle: ServerHandlerImplementation; + abstract handle( + input: Input, + connectionInfo: ConnectionInfo, + ctx: ContextCancellable, + ): AsyncIterable; } abstract class ClientHandler< @@ -46,7 +55,11 @@ abstract class ClientHandler< Input extends JSONValue = JSONValue, Output extends JSONValue = JSONValue, > extends Handler { - abstract handle: ClientHandlerImplementation; + abstract handle( + input: AsyncIterable, + connectionInfo: ConnectionInfo, + ctx: ContextCancellable, + ): Promise; } abstract class UnaryHandler< @@ -54,7 +67,11 @@ abstract class UnaryHandler< Input extends JSONValue = JSONValue, Output extends JSONValue = JSONValue, > extends Handler { - abstract handle: UnaryHandlerImplementation; + abstract handle( + input: Input, + connectionInfo: ConnectionInfo, + ctx: ContextCancellable, + ): Promise; } export { diff --git a/src/clientRPC/handlers/agentStatus.ts b/src/clientRPC/handlers/agentStatus.ts index 9bfb0a7d7..83d37fb8e 100644 --- a/src/clientRPC/handlers/agentStatus.ts +++ b/src/clientRPC/handlers/agentStatus.ts @@ -1,4 +1,3 @@ -import type { UnaryHandlerImplementation } from '../../RPC/types'; import type KeyRing from '../../keys/KeyRing'; import type CertManager from '../../keys/CertManager'; import type Logger from '@matrixai/logger'; @@ -29,10 +28,7 @@ class AgentStatusHandler extends UnaryHandler< WithMetadata, WithMetadata > { - public handle: UnaryHandlerImplementation< - WithMetadata, - WithMetadata - > = async () => { + public async handle(): Promise> { return { pid: process.pid, nodeId: nodesUtils.encodeNodeId(this.container.keyRing.getNodeId()), @@ -40,7 +36,7 @@ class AgentStatusHandler extends UnaryHandler< keysUtils.publicKeyToJWK(this.container.keyRing.keyPair.publicKey), ), }; - }; + } } export { AgentStatusHandler, agentStatusCaller }; diff --git a/src/clientRPC/handlers/agentUnlock.ts b/src/clientRPC/handlers/agentUnlock.ts index b4f7febe9..dce46357f 100644 --- a/src/clientRPC/handlers/agentUnlock.ts +++ b/src/clientRPC/handlers/agentUnlock.ts @@ -1,4 +1,3 @@ -import type { UnaryHandlerImplementation } from '../../RPC/types'; import type Logger from '@matrixai/logger'; import type { WithMetadata } from '../types'; import { UnaryHandler } from '../../RPC/handlers'; @@ -11,17 +10,12 @@ class AgentUnlockHandler extends UnaryHandler< WithMetadata, WithMetadata > { - public handle: UnaryHandlerImplementation = - async () => { - // This is a NOP handler, - // authentication and unlocking is handled via middleware. - // Failure to authenticate will be an error from the middleware layer. - return { - metadata: { - Authorization: '', - }, - }; - }; + public async handle(): Promise { + // This is a NOP handler, + // authentication and unlocking is handled via middleware. + // Failure to authenticate will be an error from the middleware layer. + return {}; + } } export { agentUnlockCaller, AgentUnlockHandler }; diff --git a/tests/RPC/RPC.test.ts b/tests/RPC/RPC.test.ts index 2793e24de..5cd72d3c5 100644 --- a/tests/RPC/RPC.test.ts +++ b/tests/RPC/RPC.test.ts @@ -1,13 +1,7 @@ -import type { - ClientHandlerImplementation, - ContainerType, - DuplexHandlerImplementation, - JsonRpcRequest, - RawHandlerImplementation, - ServerHandlerImplementation, - UnaryHandlerImplementation, -} from '@/RPC/types'; +import type { ContainerType, JsonRpcRequest } from '@/RPC/types'; import type { ConnectionInfo } from '@/network/types'; +import type { ReadableStream } from 'stream/web'; +import type { JSONValue } from '@/types'; import { fc, testProp } from '@fast-check/jest'; import Logger, { LogLevel, StreamHandler } from '@matrixai/logger'; import RPCServer from '@/RPC/RPCServer'; @@ -44,10 +38,13 @@ describe('RPC', () => { let header: JsonRpcRequest | undefined; class TestMethod extends RawHandler { - public handle: RawHandlerImplementation = ([input, header_]) => { + public handle( + input: [ReadableStream, JsonRpcRequest], + ): ReadableStream { + const [stream, header_] = input; header = header_; - return input; - }; + return stream; + } } const rpcServer = await RPCServer.createRPCServer({ manifest: { @@ -96,11 +93,11 @@ describe('RPC', () => { Uint8Array >(); class TestMethod extends DuplexHandler { - public handle: DuplexHandlerImplementation = async function* (input) { - for await (const val of input) { - yield val; - } - }; + public async *handle( + input: AsyncIterable, + ): AsyncIterable { + yield* input; + } } const rpcServer = await RPCServer.createRPCServer({ manifest: { @@ -143,12 +140,11 @@ describe('RPC', () => { >(); class TestMethod extends ServerHandler { - public handle: ServerHandlerImplementation = - async function* (input) { - for (let i = 0; i < input; i++) { - yield i; - } - }; + public async *handle(input: number): AsyncIterable { + for (let i = 0; i < input; i++) { + yield i; + } + } } const rpcServer = await RPCServer.createRPCServer({ @@ -188,15 +184,13 @@ describe('RPC', () => { >(); class TestMethod extends ClientHandler { - public handle: ClientHandlerImplementation = async ( - input, - ) => { + public async handle(input: AsyncIterable): Promise { let acc = 0; for await (const number of input) { acc += number; } return acc; - }; + } } const rpcServer = await RPCServer.createRPCServer({ manifest: { @@ -239,7 +233,9 @@ describe('RPC', () => { >(); class TestMethod extends UnaryHandler { - public handle: UnaryHandlerImplementation = async (input) => input; + public async handle(input: JSONValue): Promise { + return input; + } } const rpcServer = await RPCServer.createRPCServer({ manifest: { diff --git a/tests/RPC/RPCServer.test.ts b/tests/RPC/RPCServer.test.ts index 02249df46..ed6412b36 100644 --- a/tests/RPC/RPCServer.test.ts +++ b/tests/RPC/RPCServer.test.ts @@ -1,18 +1,14 @@ import type { - ClientHandlerImplementation, ContainerType, - DuplexHandlerImplementation, JsonRpcRequest, JsonRpcResponse, JsonRpcResponseError, - RawHandlerImplementation, - ServerHandlerImplementation, - UnaryHandlerImplementation, } from '@/RPC/types'; import type { JSONValue } from '@/types'; import type { ConnectionInfo, Host, Port } from '@/network/types'; import type { NodeId } from '@/ids'; import type { ReadableWritablePair } from 'stream/web'; +import type { ContextCancellable } from '@/contexts/types'; import { TransformStream, ReadableStream } from 'stream/web'; import { fc, testProp } from '@fast-check/jest'; import Logger, { LogLevel, StreamHandler } from '@matrixai/logger'; @@ -74,7 +70,7 @@ describe(`${RPCServer.name}`, () => { rpcTestUtils.binaryStreamToSnippedStream([4, 7, 13, 2, 6]), ); class TestHandler extends RawHandler { - public handle: RawHandlerImplementation = ([input]) => { + public handle([input, _header]): ReadableStream { void (async () => { for await (const _ of input) { // No touch, only consume @@ -86,7 +82,7 @@ describe(`${RPCServer.name}`, () => { controller.close(); }, }); - }; + } } const rpcServer = await RPCServer.createRPCServer({ manifest: { @@ -111,12 +107,14 @@ describe(`${RPCServer.name}`, () => { async (messages) => { const stream = rpcTestUtils.messagesToReadableStream(messages); class TestMethod extends DuplexHandler { - public handle: DuplexHandlerImplementation = async function* (input) { + public async *handle( + input: AsyncIterable, + ): AsyncIterable { for await (const val of input) { yield val; break; } - }; + } } const rpcServer = await RPCServer.createRPCServer({ manifest: { @@ -140,13 +138,15 @@ describe(`${RPCServer.name}`, () => { async (messages) => { const stream = rpcTestUtils.messagesToReadableStream(messages); class TestMethod extends ClientHandler { - public handle: ClientHandlerImplementation = async function (input) { + public async handle( + input: AsyncIterable, + ): Promise { let count = 0; for await (const _ of input) { count += 1; } return count; - }; + } } const rpcServer = await RPCServer.createRPCServer({ manifest: { @@ -170,12 +170,11 @@ describe(`${RPCServer.name}`, () => { async (messages) => { const stream = rpcTestUtils.messagesToReadableStream(messages); class TestMethod extends ServerHandler { - public handle: ServerHandlerImplementation = - async function* (input) { - for (let i = 0; i < input; i++) { - yield i; - } - }; + public async *handle(input: number): AsyncIterable { + for (let i = 0; i < input; i++) { + yield i; + } + } } const rpcServer = await RPCServer.createRPCServer({ manifest: { @@ -199,7 +198,9 @@ describe(`${RPCServer.name}`, () => { async (messages) => { const stream = rpcTestUtils.messagesToReadableStream(messages); class TestMethod extends UnaryHandler { - public handle: UnaryHandlerImplementation = async (input) => input; + public async handle(input: JSONValue): Promise { + return input; + } } const rpcServer = await RPCServer.createRPCServer({ manifest: { @@ -228,12 +229,14 @@ describe(`${RPCServer.name}`, () => { C: Symbol('c'), }; class TestMethod extends DuplexHandler { - public handle: DuplexHandlerImplementation = async function* (input) { + public async *handle( + input: AsyncIterable, + ): AsyncIterable { expect(this.container).toBe(container); for await (const val of input) { yield val; } - }; + } } const rpcServer = await RPCServer.createRPCServer({ @@ -267,15 +270,15 @@ describe(`${RPCServer.name}`, () => { }; let handledConnectionInfo; class TestMethod extends DuplexHandler { - public handle: DuplexHandlerImplementation = async function* ( - input, - connectionInfo_, - ) { + public async *handle( + input: AsyncIterable, + connectionInfo_: ConnectionInfo, + ): AsyncIterable { handledConnectionInfo = connectionInfo_; for await (const val of input) { yield val; } - }; + } } const rpcServer = await RPCServer.createRPCServer({ manifest: { @@ -294,21 +297,19 @@ describe(`${RPCServer.name}`, () => { expect(handledConnectionInfo).toBe(connectionInfo); }, ); - // Problem with the tap stream. It seems to block the whole stream. - // If I don't pipe the tap to the output we actually iterate over some data. testProp('Handler can be aborted', [specificMessageArb], async (messages) => { const stream = rpcTestUtils.messagesToReadableStream(messages); class TestMethod extends DuplexHandler { - public handle: DuplexHandlerImplementation = async function* ( - input, - _connectionInf, - ctx, - ) { + public async *handle( + input: AsyncIterable, + connectionInfo_: ConnectionInfo, + ctx: ContextCancellable, + ): AsyncIterable { for await (const val of input) { if (ctx.signal.aborted) throw ctx.signal.reason; yield val; } - }; + } } const rpcServer = await RPCServer.createRPCServer({ manifest: { @@ -350,11 +351,13 @@ describe(`${RPCServer.name}`, () => { testProp('Handler yields nothing', [specificMessageArb], async (messages) => { const stream = rpcTestUtils.messagesToReadableStream(messages); class TestMethod extends DuplexHandler { - public handle: DuplexHandlerImplementation = async function* (input) { + public async *handle( + input: AsyncIterable, + ): AsyncIterable { for await (const _ of input) { // Do nothing, just consume } - }; + } } const rpcServer = await RPCServer.createRPCServer({ manifest: { @@ -378,9 +381,9 @@ describe(`${RPCServer.name}`, () => { async (messages, error) => { const stream = rpcTestUtils.messagesToReadableStream(messages); class TestMethod extends DuplexHandler { - public handle: DuplexHandlerImplementation = async function* () { + public async *handle(): AsyncIterable { throw error; - }; + } } const rpcServer = await RPCServer.createRPCServer({ manifest: { @@ -416,9 +419,9 @@ describe(`${RPCServer.name}`, () => { async (messages) => { const stream = rpcTestUtils.messagesToReadableStream(messages); class TestMethod extends DuplexHandler { - public handle: DuplexHandlerImplementation = async function* () { + public async *handle(): AsyncIterable { throw new rpcErrors.ErrorRpcPlaceholderConnectionError(); - }; + } } const rpcServer = await RPCServer.createRPCServer({ manifest: { @@ -450,11 +453,11 @@ describe(`${RPCServer.name}`, () => { testProp('forward middlewares', [specificMessageArb], async (messages) => { const stream = rpcTestUtils.messagesToReadableStream(messages); class TestMethod extends DuplexHandler { - public handle: DuplexHandlerImplementation = async function* (input) { - for await (const val of input) { - yield val; - } - }; + public async *handle( + input: AsyncIterable, + ): AsyncIterable { + yield* input; + } } const middleware = rpcUtils.defaultServerMiddlewareWrapper(() => { return { @@ -495,11 +498,11 @@ describe(`${RPCServer.name}`, () => { testProp('reverse middlewares', [specificMessageArb], async (messages) => { const stream = rpcTestUtils.messagesToReadableStream(messages); class TestMethod extends DuplexHandler { - public handle: DuplexHandlerImplementation = async function* (input) { - for await (const val of input) { - yield val; - } - }; + public async *handle( + input: AsyncIterable, + ): AsyncIterable { + yield* input; + } } const middleware = rpcUtils.defaultServerMiddlewareWrapper(() => { return { @@ -543,11 +546,11 @@ describe(`${RPCServer.name}`, () => { async (message) => { const stream = rpcTestUtils.messagesToReadableStream([message]); class TestMethod extends DuplexHandler { - public handle: DuplexHandlerImplementation = async function* (input) { - for await (const val of input) { - yield val; - } - }; + public async *handle( + input: AsyncIterable, + ): AsyncIterable { + yield* input; + } } const middleware = rpcUtils.defaultServerMiddlewareWrapper(() => { let first = true; diff --git a/tests/clientRPC/websocket.test.ts b/tests/clientRPC/websocket.test.ts index ce117cffb..813935ab7 100644 --- a/tests/clientRPC/websocket.test.ts +++ b/tests/clientRPC/websocket.test.ts @@ -2,6 +2,7 @@ import type { TLSConfig } from '@/network/types'; import type { Server } from 'https'; import type { WebSocketServer } from 'ws'; import type { ClientManifest } from '@/RPC/types'; +import type { JSONValue } from '@/types'; import fs from 'fs'; import path from 'path'; import os from 'os'; @@ -60,14 +61,14 @@ describe('websocket', () => { test('websocket should work with RPC', async () => { // Setting up server class Test1 extends UnaryHandler { - public handle = async (params) => { - return params; - }; + public async handle(input: JSONValue): Promise { + return input; + } } class Test2 extends UnaryHandler { - public handle = async () => { + public async handle(): Promise { return { hello: 'not world' }; - }; + } } rpcServer = await RPCServer.createRPCServer({ manifest: { From 632281c844779fab692747d28784d95f4b99e971 Mon Sep 17 00:00:00 2001 From: Brian Botha Date: Fri, 10 Feb 2023 19:10:50 +1100 Subject: [PATCH 39/44] feat: updating metadata wrapper type name Related #500 Related #501 Related #502 [ci skip] --- src/clientRPC/handlers/agentStatus.ts | 12 +++++----- src/clientRPC/handlers/agentUnlock.ts | 13 ++++++---- src/clientRPC/types.ts | 10 ++++++-- src/clientRPC/utils.ts | 34 +++++++++++++-------------- 4 files changed, 39 insertions(+), 30 deletions(-) diff --git a/src/clientRPC/handlers/agentStatus.ts b/src/clientRPC/handlers/agentStatus.ts index 83d37fb8e..8e53106f6 100644 --- a/src/clientRPC/handlers/agentStatus.ts +++ b/src/clientRPC/handlers/agentStatus.ts @@ -2,7 +2,7 @@ import type KeyRing from '../../keys/KeyRing'; import type CertManager from '../../keys/CertManager'; import type Logger from '@matrixai/logger'; import type { NodeIdEncoded } from '../../ids'; -import type { WithMetadata } from '../types'; +import type { RPCRequestParams, RPCResponseResult } from '../types'; import * as nodesUtils from '../../nodes/utils'; import * as keysUtils from '../../keys/utils'; import { UnaryHandler } from '../../RPC/handlers'; @@ -15,8 +15,8 @@ type StatusResult = { }; const agentStatusCaller = new UnaryCaller< - WithMetadata, - WithMetadata + RPCRequestParams, + RPCResponseResult >(); class AgentStatusHandler extends UnaryHandler< @@ -25,10 +25,10 @@ class AgentStatusHandler extends UnaryHandler< certManager: CertManager; logger: Logger; }, - WithMetadata, - WithMetadata + RPCRequestParams, + RPCResponseResult > { - public async handle(): Promise> { + public async handle(): Promise> { return { pid: process.pid, nodeId: nodesUtils.encodeNodeId(this.container.keyRing.getNodeId()), diff --git a/src/clientRPC/handlers/agentUnlock.ts b/src/clientRPC/handlers/agentUnlock.ts index dce46357f..e0c6756df 100644 --- a/src/clientRPC/handlers/agentUnlock.ts +++ b/src/clientRPC/handlers/agentUnlock.ts @@ -1,16 +1,19 @@ import type Logger from '@matrixai/logger'; -import type { WithMetadata } from '../types'; +import type { RPCRequestParams, RPCResponseResult } from '../types'; import { UnaryHandler } from '../../RPC/handlers'; import { UnaryCaller } from '../../RPC/callers'; -const agentUnlockCaller = new UnaryCaller(); +const agentUnlockCaller = new UnaryCaller< + RPCRequestParams, + RPCResponseResult +>(); class AgentUnlockHandler extends UnaryHandler< { logger: Logger }, - WithMetadata, - WithMetadata + RPCRequestParams, + RPCResponseResult > { - public async handle(): Promise { + public async handle(): Promise { // This is a NOP handler, // authentication and unlocking is handled via middleware. // Failure to authenticate will be an error from the middleware layer. diff --git a/src/clientRPC/types.ts b/src/clientRPC/types.ts index 8ae04e161..371511752 100644 --- a/src/clientRPC/types.ts +++ b/src/clientRPC/types.ts @@ -3,10 +3,16 @@ import type { JSONValue } from '../types'; // eslint-disable-next-line type NoData = {}; -type WithMetadata = NoData> = { +type RPCRequestParams = NoData> = { metadata?: { [Key: string]: JSONValue; } & Partial<{ Authorization: string }>; } & Omit; -export type { WithMetadata, NoData }; +type RPCResponseResult = NoData> = { + metadata?: { + [Key: string]: JSONValue; + } & Partial<{ Authorization: string }>; +} & Omit; + +export type { RPCRequestParams, RPCResponseResult, NoData }; diff --git a/src/clientRPC/utils.ts b/src/clientRPC/utils.ts index ccd5f8b74..2aaf8df11 100644 --- a/src/clientRPC/utils.ts +++ b/src/clientRPC/utils.ts @@ -2,7 +2,7 @@ import type { SessionToken } from '../sessions/types'; import type KeyRing from '../keys/KeyRing'; import type SessionManager from '../sessions/SessionManager'; import type { Session } from '../sessions'; -import type { WithMetadata } from './types'; +import type { RPCResponseResult, RPCRequestParams } from './types'; import type { JsonRpcRequest, JsonRpcResponse, @@ -25,7 +25,7 @@ import { promise } from '../utils'; async function authenticate( sessionManager: SessionManager, keyRing: KeyRing, - message: JsonRpcRequest, + message: JsonRpcRequest, ) { if (message.params == null) throw new clientErrors.ErrorClientAuthMissing(); if (message.params.metadata == null) { @@ -58,7 +58,7 @@ async function authenticate( return `Bearer ${token}`; } -function decodeAuth(messageParams: WithMetadata) { +function decodeAuth(messageParams: RPCRequestParams) { const auth = messageParams.metadata?.Authorization; if (auth == null || !auth.startsWith('Bearer ')) { return; @@ -75,10 +75,10 @@ function authenticationMiddlewareServer( sessionManager: SessionManager, keyRing: KeyRing, ): MiddlewareFactory< - JsonRpcRequest, - JsonRpcRequest, - JsonRpcResponse, - JsonRpcResponse + JsonRpcRequest, + JsonRpcRequest, + JsonRpcResponse, + JsonRpcResponse > { return () => { let forwardFirst = true; @@ -86,8 +86,8 @@ function authenticationMiddlewareServer( let outgoingToken: string | null = null; return { forward: new TransformStream< - JsonRpcRequest, - JsonRpcRequest + JsonRpcRequest, + JsonRpcRequest >({ transform: async (chunk, controller) => { if (forwardFirst) { @@ -132,17 +132,17 @@ function authenticationMiddlewareServer( function authenticationMiddlewareClient( session: Session, ): MiddlewareFactory< - JsonRpcRequest, - JsonRpcRequest, - JsonRpcResponse, - JsonRpcResponse + JsonRpcRequest, + JsonRpcRequest, + JsonRpcResponse, + JsonRpcResponse > { return () => { let forwardFirst = true; return { forward: new TransformStream< - JsonRpcRequest, - JsonRpcRequest + JsonRpcRequest, + JsonRpcRequest >({ transform: async (chunk, controller) => { if (forwardFirst) { @@ -164,8 +164,8 @@ function authenticationMiddlewareClient( }, }), reverse: new TransformStream< - JsonRpcResponse, - JsonRpcResponse + JsonRpcResponse, + JsonRpcResponse >({ transform: async (chunk, controller) => { controller.enqueue(chunk); From ffb8b57ecc2f17ce5240b8f11e562450672eed36 Mon Sep 17 00:00:00 2001 From: Brian Botha Date: Mon, 13 Feb 2023 13:25:49 +1100 Subject: [PATCH 40/44] fix: changed client stream caller signature Related #501 [ci skip] --- src/RPC/RPCClient.ts | 56 +++++++++++-------------------------- src/RPC/types.ts | 2 +- tests/RPC/RPC.test.ts | 19 ++++++------- tests/RPC/RPCClient.test.ts | 34 +++++++++++----------- 4 files changed, 42 insertions(+), 69 deletions(-) diff --git a/src/RPC/RPCClient.ts b/src/RPC/RPCClient.ts index 242ecf8be..8c01f91e8 100644 --- a/src/RPC/RPCClient.ts +++ b/src/RPC/RPCClient.ts @@ -75,7 +75,7 @@ class RPCClient { case 'SERVER': return (params) => this.serverStreamCaller(method, params); case 'CLIENT': - return (f) => this.clientStreamCaller(method, f); + return () => this.clientStreamCaller(method); case 'DUPLEX': return (f) => this.duplexStreamCaller(method, f); case 'RAW': @@ -142,8 +142,6 @@ class RPCClient { return this.rawMethodsProxy as MapRawCallers; } - // Convenience methods - @ready(new rpcErrors.ErrorRpcDestroyed()) public async unaryCaller( method: string, @@ -183,21 +181,22 @@ class RPCClient { @ready(new rpcErrors.ErrorRpcDestroyed()) public async clientStreamCaller( method: string, - f: (output: Promise) => AsyncIterable, - ): Promise { - const callerInterface = await this.rawClientStreamCaller(method); - const writer = callerInterface.writable.getWriter(); - let running = true; - for await (const value of f(callerInterface.output)) { - if (value === undefined) { - await writer.close(); - running = false; + ): Promise<{ + output: Promise; + writable: WritableStream; + }> { + const callerInterface = await this.rawDuplexStreamCaller(method); + const reader = callerInterface.readable.getReader(); + const output = reader.read().then(({ value, done }) => { + if (done) { + throw new rpcErrors.ErrorRpcRemoteError('Stream ended before response'); } - // Write while running otherwise consume until ended - if (running) await writer.write(value); - } - // If ended before finish running then close writer - if (running) await writer.close(); + return value; + }); + return { + output, + writable: callerInterface.writable, + }; } @ready(new rpcErrors.ErrorRpcDestroyed()) @@ -263,29 +262,6 @@ class RPCClient { tempWriter.releaseLock(); return streamPair; } - - protected async rawClientStreamCaller< - I extends JSONValue, - O extends JSONValue, - >( - method: string, - ): Promise<{ - output: Promise; - writable: WritableStream; - }> { - const callerInterface = await this.rawDuplexStreamCaller(method); - const reader = callerInterface.readable.getReader(); - const output = reader.read().then(({ value, done }) => { - if (done) { - throw new rpcErrors.ErrorRpcRemoteError('Stream ended before response'); - } - return value; - }); - return { - output, - writable: callerInterface.writable, - }; - } } export default RPCClient; diff --git a/src/RPC/types.ts b/src/RPC/types.ts index b169f4c74..4e6633527 100644 --- a/src/RPC/types.ts +++ b/src/RPC/types.ts @@ -161,7 +161,7 @@ type ServerCallerImplementation< type ClientCallerImplementation< I extends JSONValue = JSONValue, O extends JSONValue = JSONValue, -> = (f: (output: Promise) => AsyncIterable) => Promise; +> = () => Promise<{ output: Promise; writable: WritableStream }>; type DuplexCallerImplementation< I extends JSONValue = JSONValue, diff --git a/tests/RPC/RPC.test.ts b/tests/RPC/RPC.test.ts index 5cd72d3c5..49b4e078c 100644 --- a/tests/RPC/RPC.test.ts +++ b/tests/RPC/RPC.test.ts @@ -208,17 +208,14 @@ describe('RPC', () => { logger, }); - await rpcClient.methods.testMethod(async function* (output) { - for (const value of values) { - yield value; - } - // Ending writes - yield undefined; - // Checking output - const expectedResult = values.reduce((p, c) => p + c); - await expect(output).resolves.toEqual(expectedResult); - }); - + const { output, writable } = await rpcClient.methods.testMethod(); + const writer = writable.getWriter(); + for (const value of values) { + await writer.write(value); + } + await writer.close(); + const expectedResult = values.reduce((p, c) => p + c); + await expect(output).resolves.toEqual(expectedResult); await rpcServer.destroy(); await rpcClient.destroy(); }, diff --git a/tests/RPC/RPCClient.test.ts b/tests/RPC/RPCClient.test.ts index d675d407e..63da3ba0f 100644 --- a/tests/RPC/RPCClient.test.ts +++ b/tests/RPC/RPCClient.test.ts @@ -175,16 +175,16 @@ describe(`${RPCClient.name}`, () => { streamPairCreateCallback: async () => streamPair, logger, }); - await rpcClient.clientStreamCaller( - methodName, - async function* (output) { - for (const param of params) { - yield param; - } - yield undefined; - expect(await output).toStrictEqual(message.result); - }, - ); + const { output, writable } = await rpcClient.clientStreamCaller< + JSONValue, + JSONValue + >(methodName); + const writer = writable.getWriter(); + for (const param of params) { + await writer.write(param); + } + await writer.close(); + expect(await output).toStrictEqual(message.result); const expectedOutput = params.map((v) => JSON.stringify({ method: methodName, @@ -518,13 +518,13 @@ describe(`${RPCClient.name}`, () => { streamPairCreateCallback: async () => streamPair, logger, }); - await rpcClient.methods.client(async function* (output) { - for (const param of params) { - yield param; - } - yield undefined; - expect(await output).toStrictEqual(message.result); - }); + const { output, writable } = await rpcClient.methods.client(); + const writer = writable.getWriter(); + for (const param of params) { + await writer.write(param); + } + expect(await output).toStrictEqual(message.result); + await writer.close(); const expectedOutput = params.map((v) => JSON.stringify({ method: 'client', From 5b3604992bfb34161da3ec4e14453d2c1d337c71 Mon Sep 17 00:00:00 2001 From: Brian Botha Date: Mon, 13 Feb 2023 16:25:09 +1100 Subject: [PATCH 41/44] tests: expanding error tests [ci skip] --- src/RPC/RPCServer.ts | 9 +- src/errors.ts | 1 + tests/RPC/RPC.test.ts | 96 ++++++++++++++++++++ tests/RPC/RPCClient.test.ts | 75 ++++++++++++++- tests/RPC/RPCServer.test.ts | 52 +++++++++-- tests/RPC/utils.ts | 28 +++++- tests/clientRPC/handlers/agentStatus.test.ts | 2 +- 7 files changed, 248 insertions(+), 15 deletions(-) diff --git a/src/RPC/RPCServer.ts b/src/RPC/RPCServer.ts index e05dcbe6b..a85afcd14 100644 --- a/src/RPC/RPCServer.ts +++ b/src/RPC/RPCServer.ts @@ -38,6 +38,7 @@ class RPCServer { static async createRPCServer({ manifest, middleware = rpcUtils.defaultServerMiddlewareWrapper(), + sensitive = false, logger = new Logger(this.name), }: { manifest: ServerManifest; @@ -47,12 +48,14 @@ class RPCServer { Uint8Array, JsonRpcResponseResult >; + sensitive?: boolean; logger?: Logger; }): Promise { logger.info(`Creating ${this.name}`); const rpcServer = new this({ manifest, middleware, + sensitive, logger, }); logger.info(`Created ${this.name}`); @@ -63,6 +66,7 @@ class RPCServer { protected logger: Logger; protected handlerMap: Map = new Map(); protected activeStreams: Set> = new Set(); + protected sensitive: boolean; protected events: EventTarget = new EventTarget(); protected middleware: MiddlewareFactory< JsonRpcRequest, @@ -74,6 +78,7 @@ class RPCServer { public constructor({ manifest, middleware, + sensitive, logger, }: { manifest: ServerManifest; @@ -83,6 +88,7 @@ class RPCServer { Uint8Array, JsonRpcResponseResult >; + sensitive: boolean; logger: Logger; }) { for (const [key, manifestItem] of Object.entries(manifest)) { @@ -128,6 +134,7 @@ class RPCServer { never(); } this.middleware = middleware; + this.sensitive = sensitive; this.logger = logger; } @@ -199,7 +206,7 @@ class RPCServer { const rpcError: JsonRpcError = { code: e.exitCode ?? sysexits.UNKNOWN, message: e.description ?? '', - data: rpcUtils.fromError(e), + data: rpcUtils.fromError(e, this.sensitive), }; const rpcErrorMessage: JsonRpcResponseError = { jsonrpc: '2.0', diff --git a/src/errors.ts b/src/errors.ts index e2114cf55..98a9dc13c 100644 --- a/src/errors.ts +++ b/src/errors.ts @@ -80,3 +80,4 @@ export * from './schema/errors'; export * from './status/errors'; export * from './validation/errors'; export * from './utils/errors'; +export * from './RPC/errors'; diff --git a/tests/RPC/RPC.test.ts b/tests/RPC/RPC.test.ts index 49b4e078c..28c9704e2 100644 --- a/tests/RPC/RPC.test.ts +++ b/tests/RPC/RPC.test.ts @@ -20,6 +20,7 @@ import { ServerCaller, UnaryCaller, } from '@/RPC/callers'; +import * as rpcErrors from '@/RPC/errors'; import * as rpcTestUtils from './utils'; describe('RPC', () => { @@ -256,4 +257,99 @@ describe('RPC', () => { await rpcClient.destroy(); }, ); + testProp( + 'RPC handles and sends errors', + [ + rpcTestUtils.safeJsonValueArb, + rpcTestUtils.errorArb(rpcTestUtils.errorArb()), + ], + async (value, error) => { + const { clientPair, serverPair } = rpcTestUtils.createTapPairs< + Uint8Array, + Uint8Array + >(); + + class TestMethod extends UnaryHandler { + public async handle(): Promise { + throw error; + } + } + const rpcServer = await RPCServer.createRPCServer({ + manifest: { + testMethod: new TestMethod({}), + }, + logger, + }); + rpcServer.handleStream(serverPair, {} as ConnectionInfo); + + const rpcClient = await RPCClient.createRPCClient({ + manifest: { + testMethod: new UnaryCaller(), + }, + streamPairCreateCallback: async () => clientPair, + logger, + }); + + const callProm = rpcClient.methods.testMethod(value); + await expect(callProm).rejects.toThrow(rpcErrors.ErrorRpcRemoteError); + await expect( + callProm.catch((e) => { + throw e.cause; + }), + ).rejects.toThrow(error); + expect(await callProm.catch((e) => JSON.stringify(e.cause))).toInclude( + 'stack', + ); + await rpcServer.destroy(); + await rpcClient.destroy(); + }, + ); + testProp( + 'RPC handles and sends sensitive errors', + [ + rpcTestUtils.safeJsonValueArb, + rpcTestUtils.errorArb(rpcTestUtils.errorArb()), + ], + async (value, error) => { + const { clientPair, serverPair } = rpcTestUtils.createTapPairs< + Uint8Array, + Uint8Array + >(); + + class TestMethod extends UnaryHandler { + public async handle(): Promise { + throw error; + } + } + const rpcServer = await RPCServer.createRPCServer({ + manifest: { + testMethod: new TestMethod({}), + }, + sensitive: true, + logger, + }); + rpcServer.handleStream(serverPair, {} as ConnectionInfo); + + const rpcClient = await RPCClient.createRPCClient({ + manifest: { + testMethod: new UnaryCaller(), + }, + streamPairCreateCallback: async () => clientPair, + logger, + }); + + const callProm = rpcClient.methods.testMethod(value); + await expect(callProm).rejects.toThrow(rpcErrors.ErrorRpcRemoteError); + await expect( + callProm.catch((e) => { + throw e.cause; + }), + ).rejects.toThrow(error); + expect( + await callProm.catch((e) => JSON.stringify(e.cause)), + ).not.toInclude('stack'); + await rpcServer.destroy(); + await rpcClient.destroy(); + }, + ); }); diff --git a/tests/RPC/RPCClient.test.ts b/tests/RPC/RPCClient.test.ts index 63da3ba0f..90c51335e 100644 --- a/tests/RPC/RPCClient.test.ts +++ b/tests/RPC/RPCClient.test.ts @@ -234,7 +234,80 @@ describe(`${RPCClient.name}`, () => { 'generic duplex caller can throw received error message', [ fc.array(rpcTestUtils.jsonRpcResponseResultArb()), - rpcTestUtils.jsonRpcResponseErrorArb(), + rpcTestUtils.jsonRpcResponseErrorArb(rpcTestUtils.errorArb()), + ], + async (messages, errorMessage) => { + const inputStream = rpcTestUtils.messagesToReadableStream([ + ...messages, + errorMessage, + ]); + const [outputResult, outputStream] = + rpcTestUtils.streamToArray(); + const streamPair: ReadableWritablePair = { + readable: inputStream, + writable: outputStream, + }; + const rpcClient = await RPCClient.createRPCClient({ + manifest: {}, + streamPairCreateCallback: async () => streamPair, + logger, + }); + const callProm = rpcClient.duplexStreamCaller( + methodName, + async function* (output) { + for await (const _ of output) { + // No touch, just consume + } + }, + ); + await expect(callProm).rejects.toThrow(rpcErrors.ErrorRpcRemoteError); + await outputResult; + await rpcClient.destroy(); + }, + ); + testProp( + 'generic duplex caller can throw received error message with sensitive', + [ + fc.array(rpcTestUtils.jsonRpcResponseResultArb()), + rpcTestUtils.jsonRpcResponseErrorArb(rpcTestUtils.errorArb(), true), + ], + async (messages, errorMessage) => { + const inputStream = rpcTestUtils.messagesToReadableStream([ + ...messages, + errorMessage, + ]); + const [outputResult, outputStream] = + rpcTestUtils.streamToArray(); + const streamPair: ReadableWritablePair = { + readable: inputStream, + writable: outputStream, + }; + const rpcClient = await RPCClient.createRPCClient({ + manifest: {}, + streamPairCreateCallback: async () => streamPair, + logger, + }); + const callProm = rpcClient.duplexStreamCaller( + methodName, + async function* (output) { + for await (const _ of output) { + // No touch, just consume + } + }, + ); + await expect(callProm).rejects.toThrow(rpcErrors.ErrorRpcRemoteError); + await outputResult; + await rpcClient.destroy(); + }, + ); + testProp( + 'generic duplex caller can throw received error message with causes', + [ + fc.array(rpcTestUtils.jsonRpcResponseResultArb()), + rpcTestUtils.jsonRpcResponseErrorArb( + rpcTestUtils.errorArb(rpcTestUtils.errorArb()), + true, + ), ], async (messages, errorMessage) => { const inputStream = rpcTestUtils.messagesToReadableStream([ diff --git a/tests/RPC/RPCServer.test.ts b/tests/RPC/RPCServer.test.ts index ed6412b36..504d8b8db 100644 --- a/tests/RPC/RPCServer.test.ts +++ b/tests/RPC/RPCServer.test.ts @@ -44,11 +44,6 @@ describe(`${RPCServer.name}`, () => { maxLength: 10, }, ); - const errorArb = fc.oneof( - fc.constant(new rpcErrors.ErrorRpcParse()), - fc.constant(new rpcErrors.ErrorRpcMessageLength()), - fc.constant(new rpcErrors.ErrorRpcRemoteError()), - ); const validToken = 'VALIDTOKEN'; const invalidTokenMessageArb = rpcTestUtils.jsonRpcRequestMessageArb( fc.constant('testMethod'), @@ -377,7 +372,47 @@ describe(`${RPCServer.name}`, () => { }); testProp( 'should send error message', - [specificMessageArb, errorArb], + [specificMessageArb, rpcTestUtils.errorArb(rpcTestUtils.errorArb())], + async (messages, error) => { + const stream = rpcTestUtils.messagesToReadableStream(messages); + class TestMethod extends DuplexHandler { + public async *handle(): AsyncIterable { + throw error; + } + } + const rpcServer = await RPCServer.createRPCServer({ + manifest: { + testMethod: new TestMethod({}), + }, + logger, + }); + let resolve, reject; + const errorProm = new Promise((resolve_, reject_) => { + resolve = resolve_; + reject = reject_; + }); + rpcServer.addEventListener('error', (thing) => { + resolve(thing); + }); + const [outputResult, outputStream] = rpcTestUtils.streamToArray(); + const readWriteStream: ReadableWritablePair = { + readable: stream, + writable: outputStream, + }; + rpcServer.handleStream(readWriteStream, {} as ConnectionInfo); + const rawErrorMessage = (await outputResult)[0]!.toString(); + expect(rawErrorMessage).toInclude('stack'); + const errorMessage = JSON.parse(rawErrorMessage); + expect(errorMessage.error.code).toEqual(error.exitCode); + expect(errorMessage.error.message).toEqual(error.description); + reject(); + await expect(errorProm).toReject(); + await rpcServer.destroy(); + }, + ); + testProp( + 'should send error message with sensitive', + [specificMessageArb, rpcTestUtils.errorArb(rpcTestUtils.errorArb())], async (messages, error) => { const stream = rpcTestUtils.messagesToReadableStream(messages); class TestMethod extends DuplexHandler { @@ -389,6 +424,7 @@ describe(`${RPCServer.name}`, () => { manifest: { testMethod: new TestMethod({}), }, + sensitive: true, logger, }); let resolve, reject; @@ -405,7 +441,9 @@ describe(`${RPCServer.name}`, () => { writable: outputStream, }; rpcServer.handleStream(readWriteStream, {} as ConnectionInfo); - const errorMessage = JSON.parse((await outputResult)[0]!.toString()); + const rawErrorMessage = (await outputResult)[0]!.toString(); + expect(rawErrorMessage).not.toInclude('stack'); + const errorMessage = JSON.parse(rawErrorMessage); expect(errorMessage.error.code).toEqual(error.exitCode); expect(errorMessage.error.message).toEqual(error.description); reject(); diff --git a/tests/RPC/utils.ts b/tests/RPC/utils.ts index 8cffba4a1..70ef328f2 100644 --- a/tests/RPC/utils.ts +++ b/tests/RPC/utils.ts @@ -14,6 +14,7 @@ import { ReadableStream, WritableStream, TransformStream } from 'stream/web'; import { fc } from '@fast-check/jest'; import * as utils from '@/utils'; import { fromError } from '@/RPC/utils'; +import * as rpcErrors from '@/RPC/errors'; /** * This is used to convert regular chunks into randomly sized chunks based on @@ -141,14 +142,16 @@ const jsonRpcResponseResultArb = ( id: idArb, }) .noShrink() as fc.Arbitrary; - -const jsonRpcErrorArb = (error: Error = new Error('test error')) => +const jsonRpcErrorArb = ( + error: fc.Arbitrary = fc.constant(new Error('test error')), + sensitive: boolean = false, +) => fc .record( { code: fc.integer(), message: fc.string(), - data: fc.constant(fromError(error)), + data: error.map((e) => fromError(e, sensitive)), }, { requiredKeys: ['code', 'message'], @@ -156,11 +159,14 @@ const jsonRpcErrorArb = (error: Error = new Error('test error')) => ) .noShrink() as fc.Arbitrary; -const jsonRpcResponseErrorArb = (error?: Error) => +const jsonRpcResponseErrorArb = ( + error?: fc.Arbitrary, + sensitive: boolean = false, +) => fc .record({ jsonrpc: fc.constant('2.0'), - error: jsonRpcErrorArb(error), + error: jsonRpcErrorArb(error, sensitive), id: idArb, }) .noShrink() as fc.Arbitrary; @@ -250,6 +256,17 @@ function createTapPairs( }; } +const errorArb = ( + cause: fc.Arbitrary = fc.constant(undefined), +) => + cause.chain((cause) => + fc.oneof( + fc.constant(new rpcErrors.ErrorRpcParse(undefined, { cause })), + fc.constant(new rpcErrors.ErrorRpcMessageLength(undefined, { cause })), + fc.constant(new rpcErrors.ErrorRpcRemoteError(undefined, { cause })), + ), + ); + export { binaryStreamToSnippedStream, binaryStreamToNoisyStream, @@ -269,4 +286,5 @@ export { streamToArray, tapTransformStream, createTapPairs, + errorArb, }; diff --git a/tests/clientRPC/handlers/agentStatus.test.ts b/tests/clientRPC/handlers/agentStatus.test.ts index d1278b387..0ec84f1f2 100644 --- a/tests/clientRPC/handlers/agentStatus.test.ts +++ b/tests/clientRPC/handlers/agentStatus.test.ts @@ -79,7 +79,7 @@ describe('agentStatus', () => { recursive: true, }); }); - test('get status %s', async () => { + test('get status', async () => { // Setup const rpcServer = await RPCServer.createRPCServer({ manifest: { From 89c6c598a4ffa81c1766bcc397cf11172caab1b1 Mon Sep 17 00:00:00 2001 From: Brian Botha Date: Mon, 13 Feb 2023 16:32:55 +1100 Subject: [PATCH 42/44] fix: removing old test files [ci skip] --- test-ajv.ts | 37 ---- test-deepkit-rpc-client.ts | 45 ----- test-deepkit-rpc-server.ts | 44 ----- test-dgram.ts | 43 ---- test-g.ts | 22 --- test-generator-exception.ts | 90 --------- test-generators.ts | 377 ------------------------------------ test-gg.ts | 211 -------------------- test-hashing.ts | 37 ---- test-muxrpc-client.ts | 176 ----------------- test-muxrpc-server.ts | 201 ------------------- test-subject.ts | 20 -- 12 files changed, 1303 deletions(-) delete mode 100644 test-ajv.ts delete mode 100644 test-deepkit-rpc-client.ts delete mode 100644 test-deepkit-rpc-server.ts delete mode 100644 test-dgram.ts delete mode 100644 test-g.ts delete mode 100644 test-generator-exception.ts delete mode 100644 test-generators.ts delete mode 100644 test-gg.ts delete mode 100644 test-hashing.ts delete mode 100644 test-muxrpc-client.ts delete mode 100644 test-muxrpc-server.ts delete mode 100644 test-subject.ts diff --git a/test-ajv.ts b/test-ajv.ts deleted file mode 100644 index bec582f79..000000000 --- a/test-ajv.ts +++ /dev/null @@ -1,37 +0,0 @@ -import { signedClaimValidate } from './src/claims/schema'; -import { ClaimIdEncoded, SignedClaim } from './src/claims/types'; -import { NodeIdEncoded } from './src/ids/types'; - -async function main () { - - const y: SignedClaim = { - payload: { - jti: 'abc' as ClaimIdEncoded, - nbf: 123, - iat: 456, - seq: 123, - prevClaimId: 'abc' as ClaimIdEncoded, - prevDigest: null, - iss: 'abc' as NodeIdEncoded, - sub: 'abc', - }, - signatures: [{ - protected: { - alg: "BLAKE2b" - }, - header: { - - }, - signature: "abc", - }] - }; - - const x = signedClaimValidate( - y - ); - - console.log(signedClaimValidate.errors); - -} - -main(); diff --git a/test-deepkit-rpc-client.ts b/test-deepkit-rpc-client.ts deleted file mode 100644 index d6fe1be1f..000000000 --- a/test-deepkit-rpc-client.ts +++ /dev/null @@ -1,45 +0,0 @@ -import { rpc, RpcKernel } from '@deepkit/rpc'; -// import { RpcClient } from '@deepkit/rpc'; -import { RpcWebSocketClient } from '@deepkit/rpc'; -// import { RpcTcpClientAdapter } from '@deepkit/rpc-tcp'; - -interface ControllerI { - hello(title: string): string; - getUser(): Promise; -} - -@rpc.controller('clientController') -class Controller { - @rpc.action() - hello(title: string): string { - return 'Hello ' + title; - } - - @rpc.action() - async getUser(): Promise { - return 'this is a user'; - } -} - -async function main () { - - const client = new RpcWebSocketClient('ws://localhost:8081'); - client.registerController(Controller, 'clientController'); - - const controller = client.controller('myController'); - - - // const result1 = await controller.hello('world'); - // const result2 = await controller.getUser(); - - // console.log(result1); - // console.log(result2); - - // client.disconnect(); -} - -main(); - - - -// instresting diff --git a/test-deepkit-rpc-server.ts b/test-deepkit-rpc-server.ts deleted file mode 100644 index 5c24ff59a..000000000 --- a/test-deepkit-rpc-server.ts +++ /dev/null @@ -1,44 +0,0 @@ -import { rpc, RpcKernel } from '@deepkit/rpc'; -import { RpcWebSocketServer } from '@deepkit/rpc-tcp'; - -@rpc.controller('Con') -class Con { - @rpc.action() - hello(title: string): string { - return 'Hello ' + title; - } - - @rpc.action() - async getUser(): Promise { - return 'this is a user'; - } -} - -async function main () { - - const kernel = new RpcKernel(); - kernel.registerController(Con, 'Con'); - kernel.controllers - kernel.createConnection - kernel.onConnection((conn) => { - conn.clientAddress - conn.controller - conn.handleMessage - conn.myPeerId - conn.onClose - conn.onMessage - conn.writer - }); - - const server = new RpcWebSocketServer(kernel, 'ws://localhost:8081'); - - server.start({ - host: 'localhost', - port: 8081, - }); - - console.log('STARTED'); - // server.close(); -} - -main(); diff --git a/test-dgram.ts b/test-dgram.ts deleted file mode 100644 index f43de4ad4..000000000 --- a/test-dgram.ts +++ /dev/null @@ -1,43 +0,0 @@ -import dgram from 'dgram'; - -// No other process should bebound on it? -// Binding to `::` is the default? -// Right I'm just wondering what it is bound to if we just send -// Default is `dns.lookup` -// The signal can be used to close the socket -const socket = dgram.createSocket('udp4'); - -socket.on('message', (msg, rinfo) => { - console.log(msg, rinfo); -}); - -socket.bind(55555, 'localhost', () => { - - const socket2 = dgram.createSocket('udp4'); - // Upon the first send, it will be bound - // But you can send it to different places - // But you don't have to bind it if you don't want to - // But then it will be randomly set upon the first send and repeatedly - socket2.bind(55551); - - socket2.send('abc', 55555, 'localhost', (e) => { - - console.log('done', e); - socket2.send('abc', 55555, 'localhost', (e) => { - console.log('done', e); - - socket2.send('abc', 55555, 'localhost', (e) => { - console.log('done', e); - - // socket.close(); - // socket2.close(); - - }); - - }); - - - }); - -}); - diff --git a/test-g.ts b/test-g.ts deleted file mode 100644 index 30300ecca..000000000 --- a/test-g.ts +++ /dev/null @@ -1,22 +0,0 @@ -function *concatStrings(): Generator { - let result = ''; - while (true) { - const data = yield; - if (data === null) { - return result; - } - result += data; - } -} - -function *combine() { - return (yield* concatStrings()) + 'FINISH'; -} - -const g = combine(); -g.next(); -g.next("a"); -g.next("b"); -g.next("c"); -const r = g.next(null); -console.log(r.value); diff --git a/test-generator-exception.ts b/test-generator-exception.ts deleted file mode 100644 index c0b51c950..000000000 --- a/test-generator-exception.ts +++ /dev/null @@ -1,90 +0,0 @@ -import process from 'process'; - -process.on('uncaughtException', () => { - console.log('Exception was uncaught'); -}); - -process.on('unhandledRejection', () => { - console.log('Rejection was unhandled'); -}); - -async function sleep(ms: number): Promise { - return await new Promise((r) => setTimeout(r, ms)); -} - -async function *gf1() { - let c = 0; - while (true) { - await sleep(100); - yield 'G1 string'; - if (c === 5) { - throw new Error('There is an Error!'); - } - c++; - } -} - -async function *gf2() { - while (true) { - await sleep(100); - try { - yield 'G2 string'; - } catch (e) { - // This yield is for the `throw` call - // It ends up being AWAITED FOR - yield; - // Then on the NEXT `next` call they will get an error - // That's how it has to work... LOL - throw(new Error('Wrapped Error')); - } - } -} - -async function main () { - const g1 = gf1(); - for (let i = 0; i < 10; i++) { - try { - console.log(await g1.next()); - } catch (e) { - console.log('Consumed an exception!'); - break; - } - } - - const g2 = gf2(); - setTimeout(async () => { - // await g.return(); - // Async generator - // If the thrown error is NOT caught - // this will return a Promise that REJECTS - // with the exception passed in - // void g2.throw(new Error('There is an Error!')).catch((e) => { - // console.log('IGNORING ERROR: ', e.message); - // }); - - console.log(await g2.throw(new Error('There is an Error!'))); - }, 250); - - for (let i = 0; i < 10; i++) { - try { - console.log(await g2.next()); - } catch (e) { - console.log('Consumed an exception!'); - break; - } - } - console.log(await g2.next()); - -} - -void main(); - -// Ok so when the stream has an exception -// If we use async generator throw -// The async generator is being consumed by the end user -// That exception cannot be passed into the `yield` -// Not even if I wait until the next loop -// Because under the while loop it will try to do that - -// The problem is here... the types will be a bit weird though -// So that's what you have to be careful about diff --git a/test-generators.ts b/test-generators.ts deleted file mode 100644 index c17b898d1..000000000 --- a/test-generators.ts +++ /dev/null @@ -1,377 +0,0 @@ -import { Subject } from 'rxjs'; - -// This example demonstrates a simple handler with -// input async generator and output async generator - -async function sleep(ms: number): Promise { - return await new Promise((r) => setTimeout(r, ms)); -} - - -// Echo handler -async function* handler1( - input: AsyncIterableIterator -): AsyncGenerator { - // This will not preserve the `return` - // for await(const chunk of input) { - // yield chunk; - // } - - // This will also not preserve the `return` - // yield* input; - - // If we want to preserve the `return` - // We must use `return` here too - // Note that technically the `any` is required - // At the end, although technically that is not allowed - return yield* input; -} - -async function client1() { - console.log('CLIENT 1 START'); - async function* input() { - yield Buffer.from('hello'); - yield Buffer.from('world'); - return Buffer.from('end'); - } - // Assume the client gets `AsyncIterableIterator` - const output = handler1(input()) as AsyncIterableIterator; - // for await (const chunk of output) { - // console.log(chunk.toString()); - // } - while (true) { - const { done, value } = await output.next(); - if (Buffer.isBuffer(value)) { - console.log(value.toString()); - } else { - console.log('end with nothing'); - } - if (done) { - break; - } - } - console.log('CLIENT 1 STOP'); -} - -// Client Streaming -async function* handler2( - input: AsyncIterableIterator -): AsyncGenerator { - let chunks = Buffer.from(''); - for await(const chunk of input) { - chunks = Buffer.concat([chunks, chunk]); - } - return chunks; -} - -async function client2() { - console.log('CLIENT 2 START'); - async function* input() { - yield Buffer.from('hello'); - yield Buffer.from('world'); - } - const output = handler2(input()) as AsyncIterableIterator; - // Cannot use for..of for returned values - // Because the `return` is not maintained - let done: boolean | undefined = false; - while (!done) { - let value: Buffer; - ({ done, value } = await output.next()); - console.log(value.toString()); - } - console.log('CLIENT 2 STOP'); -} - -// Server streaming -async function* handler3( - _: AsyncIterableIterator -): AsyncGenerator { - // This handler doesn't care about the input - // It doesn't even bother processing it - yield Buffer.from('hello'); - yield Buffer.from('world'); - // Can we use the `return` to indicate an "early close"? - return Buffer.from('end'); - // It is possible to return `undefined` - // return; -} - -async function client3() { - console.log('CLIENT 3 START'); - // The RPC system can default `undefined` to be an empty async generator - const output = handler3((async function* () {})()) as AsyncIterableIterator; - while (true) { - const { done, value } = await output.next(); - if (Buffer.isBuffer(value)) { - console.log(value.toString()); - } else { - console.log('end with nothing'); - } - if (done) { - break; - } - } - console.log('CLIENT 3 STOP'); -} - -// Duplex streaming -// Pull-on both ends -async function *handler4( - input: AsyncIterableIterator -): AsyncGenerator { - // Note that - // the reason why we return `AsyncGenerator` - // Is because technically the user of this - // Can be used with `return()` and `throw()` - // But it is important to realise the types - // Can be more flexible - // We may wish to create our own types to be compatible - - // This concurrently consumes and concurrently produces - // The order is not sequenced - // How do we do this? - // Well something has to indicate consumption - // Something has to indicate production - // But they should be done in parallel - - // This is something that can be done - // by converting them to web streams (but that focuses on buffers) - // Alternatively by converting it to an event emitter? - // Or through rxjs... let's try that soon - - void (async () => { - // It can be expected that the input will end - // when the connection is stopped - // Or if abruptly we must consider the catching an exception - while (true) { - const { done, value } = await input.next(); - if (Buffer.isBuffer(value)) { - console.log('received', value.toString()); - } - if (done) { - console.log('received done'); - break; - } - } - })(); - - let counter = 0; - while (true) { - yield Buffer.from(counter.toString()); - counter++; - } - - // how do we know when to stop consuming? - // remember that once the connection stops - // we need to indicate that when we are finished - // remember that the thing should eventually be done - // otherwise we have a dangling promise - // that's kind of important - // wait we have an issue here - // how do we know when we are "finished"? - // or do we just `void` it? -} - -async function client4() { - console.log('CLIENT 4 START'); - async function *input() { - yield Buffer.from('hello'); - yield Buffer.from('world'); - return; - } - const output = handler4(input()) as AsyncIterableIterator; - console.log(await output.next()); - console.log(await output.next()); - console.log(await output.next()); - console.log(await output.next()); - console.log(await output.next()); - console.log(await output.next()); - - // if we want to "finish" the stream - // we can just stop consuming the `next()` - // But there's an issue here - console.log('CLIENT 4 STOP'); -} - -// How to "break" connection -async function* handler5( - input: AsyncIterableIterator -): AsyncGenerator { - while (true) { - let value, done; - try { - ({ value, done } = await input.next()); - } catch (e) { - console.log('SERVER GOT ERROR:', e.message); - break; - } - console.log('server received', value, done); - yield Buffer.from('GOT IT'); - if (done) { - console.log('server done'); - break; - } - } - return; -} - -async function client5() { - console.log('CLIENT 5 START'); - // In this scenario - async function* input() { - while (true) { - await sleep(100); - try { - yield Buffer.from('hello'); - } catch (e) { - yield; - throw e; - } - } - } - const inputG = input(); - const output = handler5(inputG as AsyncIterableIterator); - setTimeout(() => { - void inputG.throw(new Error('Connection Failed')); - }, 250); - while(true) { - const { done, value } = await output.next(); - console.log('client received', value); - if (done) { - break; - } - } - console.log('CLIENT 5 STOP'); -} - -// Convert to `push` - -// This is a push based system -// if you don't answer it -// let's see -const subject = new Subject(); - -subject.subscribe({ - next: (v) => console.log('PUSH:', v) -}); - -async function *handler6 ( - input: AsyncIterableIterator -): AsyncGenerator { - - // This is "done" asynchronously, while we pull from the stream - // How to do this in an asynchronus way? - const p = (async () => { - while (true) { - const { value, done } = await input.next(); - subject.next(value); - if (done) { - break; - } - } - })(); - - await sleep(100); - - yield Buffer.from('Hello'); - yield Buffer.from('World'); - // The stream is FINISHED - // but is the function call still completing? - // Consider what happens if that is the case - // We may want the function's lifecycle to be more complete - - // Await to finish this - // This is what allows you to capture any errors - // And the RPC system to throw it back up!!! - await p; - - // This sort of means that the OUTPUT stream isn't finished - // UNTIL you are finishign the INPUT stream - // This is a bit of a problem - // You can also send it earlier - // But if there's an exception in the processing... - - // WELL YEA... you'd need to keep the output stream open - // while you are consuming data - // otherwise you cannot signal that something failed - return; -} - -async function client6() { - console.log('CLIENT 6 START'); - // In this scenario - async function* input() { - yield Buffer.from('first'); - yield Buffer.from('second'); - return; - } - const output = handler6(input()); - while(true) { - const { done, value } = await output.next(); - console.log('client received', value); - if (done) { - break; - } - } - console.log('CLIENT 6 STOP'); -} - - -async function main() { - // await client1(); - // await client2(); - // await client3(); - // await client4(); - // await client5(); - await client6(); -} - -void main(); - -// We assume that the RPC wrapper would plumb the async generator data -// into the underlying web stream provided by the transport layer - -// The async generator `return` can be used to indicate and early -// finish to the the stream -// If `return;` is used, no last chunk is written -// If `return buf;` is used, then the buf is the last chunk to be written -// It also means the `value` could be `undefined` - -// It is possible to "force" a `return` to be applied on the outside -// this mean the `input` stream can be used -// Abort signal can also be used to indicate asynchronous cancellation -// But that is supposed to be used to cancel async operations -// Does the async generator for input stream also get a `ctx`? -// What about `throw`? Does it cancel the stream? - -// What about `ixjs`? Should this embed `ixjs`, so it can be more easily -// used? Technically ixjs works on the iterable, not on the generator -// It doesn't maintain the generator itself right? -// It would be nice if it was still a generator. - -// How to deal with metadata? For authentication... -// Is it part of the RPC system, leading and trailing metadata? -// Each message could have a metadata -// Does it depend on the RPC system itself? -// What about the transport layer? - -// Also if we enable generators -// we technically can communicate back -// that should be disallowed (since it doesn't make sense) -// perhaps we can use a different type -// Like `AsyncIterableIterator` instead of the same thing? -// It limits it to `next` -// Which is interesting - -// I think this is more correct -// You want to "take" an AsyncIterableIterator -// But the client side would get an AsyncIterableIterator -// But pass in an AsyncGenerator -// I think this makes more sense... -// async function *lol( -// x: AsyncIterableIterator -// ): AsyncGenerator { -// yield Buffer.from('hello'); -// return; -// } - diff --git a/test-gg.ts b/test-gg.ts deleted file mode 100644 index 90f3e7d88..000000000 --- a/test-gg.ts +++ /dev/null @@ -1,211 +0,0 @@ -import fc from 'fast-check'; -import type { ClaimIdEncoded, IdentityId, NodeId, ProviderId } from './src/ids'; -import { DB } from '@matrixai/db'; -import ACL from './src/acl/ACL'; -import GestaltGraph from './src/gestalts/GestaltGraph'; -import { IdInternal } from '@matrixai/id'; -import Logger, { LogLevel, StreamHandler, formatting } from '@matrixai/logger'; -import * as ids from './src/ids'; - -const nodeIdArb = fc.uint8Array({ minLength: 32, maxLength: 32 }).map( - IdInternal.create -) as fc.Arbitrary; - -// const nodeId = IdInternal.fromBuffer(Buffer.allocUnsafe(32)); - -async function main() { - - // Top level - // but we cannot raise the bottom level - // we can only hide levels - // or filter - // You could also set a filter - - const logger = new Logger( - 'TEST', - LogLevel.DEBUG, - [ - new StreamHandler( - formatting.format`${formatting.level}:${formatting.keys}:${formatting.msg}` - ), - ] - ); - - const dbLogger = logger.getChild('DB'); - dbLogger.setLevel(LogLevel.INFO); - - const db = await DB.createDB({ - dbPath: 'tmp/db', - logger: dbLogger, - fresh: true, - }); - - const aclLogger = logger.getChild('ACL'); - aclLogger.setLevel(LogLevel.INFO); - - const acl = await ACL.createACL({ - db, - logger: aclLogger, - }); - - - const ggLogger = logger.getChild('GestaltGraph'); - ggLogger.setLevel(LogLevel.DEBUG); - - const gg = await GestaltGraph.createGestaltGraph({ - db, - acl, - logger: ggLogger, - }); - - const nodeId1 = fc.sample(nodeIdArb, 1)[0]; - - - await gg.setNode({ - nodeId: nodeId1 - }); - - const nodeId2 = fc.sample(nodeIdArb, 1)[0]; - - await gg.setNode({ - nodeId: nodeId2, - }); - - const nodeId3 = fc.sample(nodeIdArb, 1)[0]; - - await gg.setNode({ - nodeId: nodeId3, - }); - - const nodeId4 = fc.sample(nodeIdArb, 1)[0]; - - await gg.setNode({ - nodeId: nodeId4, - }); - - const nodeId5 = fc.sample(nodeIdArb, 1)[0]; - - await gg.setNode({ - nodeId: nodeId5, - }); - - await gg.setIdentity({ - providerId: '123' as ProviderId, - identityId: 'abc' as IdentityId - }); - - await gg.linkNodeAndNode( - { - nodeId: nodeId1 - }, - { - nodeId: nodeId2 - }, - { - meta: {}, - claim: { - payload: { - iss: ids.encodeNodeId(nodeId1), - sub: ids.encodeNodeId(nodeId2), - jti: 'asfoiuadf' as ClaimIdEncoded, - iat: 123, - nbf: 123, - seq: 123, - prevClaimId: null, - prevDigest: null - }, - signatures: [] - } - } - ); - - await gg.linkNodeAndNode( - { - nodeId: nodeId1 - }, - { - nodeId: nodeId3 - }, - { - meta: {}, - claim: { - payload: { - iss: ids.encodeNodeId(nodeId1), - sub: ids.encodeNodeId(nodeId3), - jti: 'asfoiuadf' as ClaimIdEncoded, - iat: 123, - nbf: 123, - seq: 123, - prevClaimId: null, - prevDigest: null - }, - signatures: [] - } - } - ); - - await gg.linkNodeAndNode( - { - nodeId: nodeId2 - }, - { - nodeId: nodeId3 - }, - { - meta: {}, - claim: { - payload: { - iss: ids.encodeNodeId(nodeId2), - sub: ids.encodeNodeId(nodeId3), - jti: 'asfoiuadf' as ClaimIdEncoded, - iat: 123, - nbf: 123, - seq: 123, - prevClaimId: null, - prevDigest: null - }, - signatures: [] - } - } - ); - - // await gg.linkNodeAndNode( - // { - // nodeId: nodeId1 - // }, - // { - // nodeId: nodeId2 - // }, - // { - // type: 'node', - // meta: {}, - // claim: { - // payload: { - // jti: 's8d9sf98s7fd98sfd7' as ClaimIdEncoded, - // iss: ids.encodeNodeId(nodeId1), - // sub: ids.encodeNodeId(nodeId2), - // iat: 123, - // nbf: 123, - // seq: 123, - // prevClaimId: null, - // prevDigest: null - // }, - // signatures: [] - // } - // } - // ); - - console.log(await db.dump(gg.dbMatrixPath, true)); - // console.log(await db.dump(gg.dbNodesPath, true)); - // console.log(await db.dump(gg.dbLinksPath, true)); - - for await (const gestalt of gg.getGestalts()) { - console.group('Gestalt'); - console.dir(gestalt, { depth: null }); - // console.log('nodes', gestalt.nodes); - console.groupEnd(); - } - -} - -main(); diff --git a/test-hashing.ts b/test-hashing.ts deleted file mode 100644 index cc8e4eed7..000000000 --- a/test-hashing.ts +++ /dev/null @@ -1,37 +0,0 @@ -import * as hash from './src/keys/utils/hash'; -import * as hashing from './src/tokens/utils'; - -async function main () { - - // thisis what it takes to do it - - const digest = hash.sha256(Buffer.from('hello world')); - console.log(hashing.sha256MultiHash(digest)); - - - - // const encodeR = await hashing.sha256M.encode(Buffer.from('abc')); - // const digestR = await hashing.sha256M.digest(Buffer.from('abc')); - - // console.log(encodeR.byteLength); - // console.log(encodeR); - - // console.log(digestR); - - // // so remember - // // that upon hashing, you have a multihash digest - - // // this is the actual byte reprentation - // // the remaining stuff still needs to be "multibase" encoded - // console.log(digestR.bytes); - - - // // so therefore - // // BASEENCODING + MULTIHASH is exactly what you want - - - - -} - -main(); diff --git a/test-muxrpc-client.ts b/test-muxrpc-client.ts deleted file mode 100644 index 15f86d3de..000000000 --- a/test-muxrpc-client.ts +++ /dev/null @@ -1,176 +0,0 @@ -import MRPC from 'muxrpc'; -import pull from 'pull-stream'; -import toPull from 'stream-to-pull-stream'; -import net from 'net'; -import { sleep } from './src/utils'; - -const manifest = { - hello: 'async', - // another: 'async', - stuff: 'source', - - sink: 'sink', - - duplex: 'duplex', -}; - -// Client needs the remote manifest, it can pass a local manifest, and then a local API - -// Remote manifest, local manifest, codec -// Local API, Permissions, ID -const client = MRPC(manifest, null)(null, null, 'CLIENT'); - -console.log(client); - -const stream = toPull.duplex(net.connect(8080)); - -// const onClose = () => { -// console.log('closed connection to muxrpc server'); -// }; - - -const mStream = client.createStream(); - -// This also takes a duplex socket, and converts it to a "pull-stream" -pull(stream, mStream, stream); - -// So now that the client is composed there -// Also interestingly... notice that the TCP socket above does its own establishment -// The RPC hence is "transport" agnostic because it works on anything that is a duplex stream -// That's pretty good - -client.hello('world', 100, (err, data) => { - if (err != null) throw err; - console.log('HELLO call1', data); -}); - -// Oh cool, it does support promises on the client side - -// client.hello('world', 50, (err, data) => { -// if (err != null) throw err; -// console.log('HELLO call2', data); -// }); - -// client.hello('world', 10, (err, data) => { -// if (err != null) throw err; -// console.log('HELLO call3', data); -// }); - -// Yep there's a muxing of the RPC calls here -// This makes alot of sense - -// No deadline... it's not finished -// Ok then there's a failure, we have 1 stream per rpc -// client.another('another').then((data) => { -// console.log('ANOTHER', data); -// }); - -console.log('SENT all hellos over'); - -// Now if you want to do a stream, it seems `pull.values` ultimately returns some sort of stream object - -// const s = client.stuff(); -// // Yea this becomes a "source" stream -// // So it is infact the same type -// console.log('stuff stream', s); - -// pull(s, pull.drain(console.log)); - -// So how does this actually "mux" the RPC calls? -// Can they be concurrent? -// I think the muxing still has to end up -// interleaving the data... -// So it's still in sequence -// But the order can be different depending on the situation - - -// This is a sink -// we need to feed datat to the seek -// How do we know when things are done...? -// client.sink( -// pull.values(['hello', 'world']), -// (e, v) => { -// console.log('got it', v); -// } -// ); - -// const sink = client.sink(); - -// pull( -// pull.values(['hello', 'world']), -// sink -// ); - -// When a "stream" is opened here -// it prevents the process from shutting down -// That's kind of bad -// We don't really want to keep the process open - -// const duplex = client.duplex(); - -// console.log('DUPLEX', duplex); - -// pull( -// pull.values([1, 2, 3, 'end']), -// duplex, -// pull.drain(console.log) -// ); - -// Nothing is "ending" the stream -// that's the problem - -// console.log('YO'); - - -// The entire MUXRPC object is an event emitter - -// console.log(client.id); - -console.log(mStream); - -// This is also asynchronous -// It ends up closing a "ws" -// Which usese `initStream` -// YOU HAVE TO SUPPLY A CALLBACK -// client.end(); - -console.log('is open', mStream.isOpen()); - -// I think this is actually wht is perfomring a remote call -// whatever... -// console.log('remote call', mStream.remoteCall.toString()); - -mStream.close(() => { - console.log('CLOSING MUXRPC STREAM'); - - // Closing the stream also closes the client - // The client can create a stream... - // That's really strange - // Ok but the client is then closed too? - console.log(client.closed); -}); - - - -// client.close(() => { -// console.log('ClOSED'); -// }); - -// Remember TCP sockets are duplex streams -// So they are already duplex concurrently -// They are also event emitters at the same time - -// But dgram sockets are EventEmitters -// They are not Duplex stream -// They are not streams at all -// Which makes sense -// But as an event emitter, that makes them concurrent in both directions too -// Messages could be sent there and also received -// All UDP datagrams can be sent to alternative destinations, even if bound to the same socket - -// So this is very interesting - -// I'm still seeing a problem. -// How does the handlers get context of the RPC call? And how does it get access to the remote side's manifest? - - diff --git a/test-muxrpc-server.ts b/test-muxrpc-server.ts deleted file mode 100644 index f24adee77..000000000 --- a/test-muxrpc-server.ts +++ /dev/null @@ -1,201 +0,0 @@ -import MRPC from 'muxrpc'; -import pull from 'pull-stream'; -import toPull from 'stream-to-pull-stream'; -import net from 'net'; -import pushable from 'pull-pushable'; -import { sleep } from './src/utils'; - -// "dynamic" manifest - -const manifest = { - hello: 'async', - // another: 'async', - stuff: 'source', - - // There's also `sink` streams - sink: 'sink', - - duplex: 'duplex', -}; - -// actual handlers, notice no promise support -const api = { - async hello(name, time, cb) { - // How are we supposed to know who contacted us? - // This is kind of stupid - - await sleep(time); - cb(null, 'hello' + name + '!'); - }, - // async another (name) { - // return 'hello' + name; - // }, - stuff() { - const s = pull.values([1,2,3,4,5]); - - // Yes, a "source" is a function - // This is a function - // Remember the "stream" is already mutated - // But this ends up returning a function - // This function basically starts the source - // The `cb` is used to read the data - // The cb gets called and it receives the data from the source - // That's where things get a bit complicated - // So the type of this is complex - // console.log('is this a source', s.toString()); - - // IT RETURNS THE SOURCE STREAM - return s; - }, - sink() { - // Cause it is a sink, it only takes data - // It does not give you back any confirmation back? - // IT RETURNS the sink stream... think about that - return pull.collect( - (e, arr) => { - console.log('SUNK', arr); - } - ); - }, - duplex() { - - // This needs to return a source and sink together - // Parameters are still passable into it at the beginning - // Sort of like how our duplex streams are structured - // We are able to pass it the initial message - - // The source cannot be `pull.values` - // Because it ends the stream - // That seems kind of wrong - - const p = pushable(); - // for (let i = 0; i < 5; i++) { - // p.push(i); - // } - // But this seems difficult to use - // How do we "consume" things, and then simultaneously - // push things to the source? - - return { - source: p, - sink: pull.drain((value) => { - // Wait I'm confused, how does this mean it ends? - // How do I know when this is ended? - - // If the `p` doesn't end - // We end up with a problem - - if (value === 'end') { - p.end(); - } - - p.push(value + 1); - }) - }; - }, -}; - -// Remote manifest, local manifest, codec -// Local API, Permissions, ID -const server = MRPC(null, manifest)(api, null, 'SERVER'); - -console.log(server); - -const muxStream = server.createStream(); - -net.createServer(socket => { - - console.log('NEW CONNECTION'); - - // The socket is converted to a duplex pull-stream - // Stream to Pull Stream is a conversion utility - // Converts NodeJS streams (classic-stream and new-stream) into a pull-stream - // It returns an object structure { source, sink } - // It is source(stream) and sink(stream, cb) - // The source will attach handlers on the `data`, `end`, `close`, and `error` events - // It pushes the stream data into the a buffers list. - // It also switches the stream to `rsume`, so it may be paused - // If the length exists, and the stream is pausable, then it will end up calling `stream.pause()` - // So upon handling a single `data` event, it will end up pausing the stream - // The idea being is that the buffer will have data - // Ok I get the idea... - // If the buffer still has data even after calling `drain`, then that means there's already data queued up - // That's why it pauses the stream - // If a stream is paused, data events will not be emitted - // The drawin runs a loop as long as there's data in the queue or ended, and the cbs is filled - // The dbs are callbacks, it shifts one of them, and applies it to a chunk of data - // Then it will check if the length is empty, and it is paused, it will then unpause it, and resume the stream - // On the write side, it is attaching handlers to the stream as well - // This time on close, finish and error. - // On the next tick ,it is then calling the `read` function - // Because it has to "read" a data from the source - // This callback then is given the data - // The data is written with `stream.write(data)` - // Anyway the point is that these is object of 2 things - // { source, sink } - const stream = toPull.duplex(socket); - - // This connects the output (source) of the stream to the muxrpc stream - // And it connects the output of the muxRPC stream to the net stream - // This is variadic function, it can take multiple things that are streams - // It pulls from left to right - // So the stream source is pulled into the muxrpc stream and then pulled into the stream again - // NET SOCKET SOURCE -> MUXRPC -> NET SOCKET SINK - // The duplex socket is being broken apart into 2 separate streams - // Then they get composed with the input and output of the muxrpc stream - // And therefore muxrpc is capable multiplexing its internal messages on top of the base socket streams - - - pull(stream, muxStream, stream); - - // every time a new connection occurs - // we have to do something different... - // I think... otherwise the streams get screwed up - // But I'm not entirely sure - // How do we do this over and over - // We have to "disconnect" - - socket.on('close', () => { - console.log('SOCKET CLOSED'); - // muxStream.close(); - }); - - socket.on('end', () => { - console.log('SOCKET ENDED'); - }); - - -}).listen(8080); - -// In a way, this pull-stream is kind of similar to the idea behind rxjs -// But it's very stream heavy, lack of typescript... etc -// Also being a pull stream, it only pulls from the source when it needs to -// I'm not sure what actually triggers it, it seems the source is given a "read" function -// So when the sink is ready, it ends up calling the `read` function - -// The looper is used for "asynchronous" looping -// Because it uses the `next` call to cycle -// This is necessary in asynchronous callbacks -// However this is not necessary if we are using async await syntax -// In a callback situation you cannot just use `while() { ... }` loop -// But you casn when using async await -// I'm confused about the `function (next) { .. }` cause the looper is not passing -// anything into the `next` parameter, so that doesn't make sense -// Right it is using the 3.0.0 of looper, which had a different design -// Ok so the point is, it's a process next tick, with an asynchronous infinite loop -// The loop repeatedly calls `read` upon finishing the `write` callback -// And it will do this until the read callback is ended -// Or if the output stream itself is ended -// This is what gives it a natural form of backpressure -// It will only "pull" something as fast as the sink can take it -// Since it only triggers a pull, when the sink is drained - -// An async iterator/iterator can do the same thing -// And thus one can "next" it as fast as the sink can read - -// Similarly a push stream would be subscribing, but it's possible to backpressure this with -// buffers or with dropping systems... naturally buffers should be used, and the application can drop data - -// This is an old structure, and I prefer modern JS with functional concepts - -// Now that the server is listening, we can create the client diff --git a/test-subject.ts b/test-subject.ts deleted file mode 100644 index dda09ede9..000000000 --- a/test-subject.ts +++ /dev/null @@ -1,20 +0,0 @@ -import { Subject } from 'rxjs'; - -const subject = new Subject(); - -// These are dropped, nobody is listening -subject.next(1); - -subject.subscribe({ - next: (v) => console.log(`observerA: ${v}`) -}); - -subject.next(2); - -// B only gets 3 and 4 -subject.subscribe({ - next: (v) => console.log(`observerB: ${v}`) -}); - -subject.next(3); -subject.next(4); From 1d88f5d343925b10dfbcdf24998d10d695ee983e Mon Sep 17 00:00:00 2001 From: Brian Botha Date: Mon, 13 Feb 2023 17:02:35 +1100 Subject: [PATCH 43/44] fix: extracting middleware to its own files [ci skip] --- src/RPC/RPCClient.ts | 3 +- src/RPC/RPCServer.ts | 3 +- src/RPC/middleware.ts | 158 +++++++++++++ src/RPC/utils.ts | 220 +----------------- src/clientRPC/authenticationMiddleware.ts | 122 ++++++++++ src/clientRPC/utils.ts | 122 +--------- tests/RPC/RPCClient.test.ts | 6 +- tests/RPC/RPCServer.test.ts | 7 +- tests/RPC/middleware.test.ts | 96 ++++++++ tests/RPC/utils.test.ts | 124 +--------- .../authenticationMiddleware.test.ts | 154 ++++++++++++ tests/clientRPC/handlers/agentUnlock.test.ts | 13 +- 12 files changed, 562 insertions(+), 466 deletions(-) create mode 100644 src/RPC/middleware.ts create mode 100644 src/clientRPC/authenticationMiddleware.ts create mode 100644 tests/RPC/middleware.test.ts create mode 100644 tests/clientRPC/authenticationMiddleware.test.ts diff --git a/src/RPC/RPCClient.ts b/src/RPC/RPCClient.ts index 8c01f91e8..f0a684a02 100644 --- a/src/RPC/RPCClient.ts +++ b/src/RPC/RPCClient.ts @@ -15,6 +15,7 @@ import type { } from './types'; import { CreateDestroy, ready } from '@matrixai/async-init/dist/CreateDestroy'; import Logger from '@matrixai/logger'; +import * as middlewareUtils from './middleware'; import * as rpcErrors from './errors'; import * as rpcUtils from './utils'; import { @@ -30,7 +31,7 @@ class RPCClient { static async createRPCClient({ manifest, streamPairCreateCallback, - middleware = rpcUtils.defaultClientMiddlewareWrapper(), + middleware = middlewareUtils.defaultClientMiddlewareWrapper(), logger = new Logger(this.name), }: { manifest: M; diff --git a/src/RPC/RPCServer.ts b/src/RPC/RPCServer.ts index a85afcd14..d33d8f053 100644 --- a/src/RPC/RPCServer.ts +++ b/src/RPC/RPCServer.ts @@ -29,6 +29,7 @@ import { } from './handlers'; import * as rpcUtils from './utils'; import * as rpcErrors from './errors'; +import * as middlewareUtils from './middleware'; import { never } from '../utils/utils'; import { sysexits } from '../errors'; @@ -37,7 +38,7 @@ interface RPCServer extends CreateDestroy {} class RPCServer { static async createRPCServer({ manifest, - middleware = rpcUtils.defaultServerMiddlewareWrapper(), + middleware = middlewareUtils.defaultServerMiddlewareWrapper(), sensitive = false, logger = new Logger(this.name), }: { diff --git a/src/RPC/middleware.ts b/src/RPC/middleware.ts new file mode 100644 index 000000000..0f3150d83 --- /dev/null +++ b/src/RPC/middleware.ts @@ -0,0 +1,158 @@ +import type { + JsonRpcMessage, + JsonRpcRequest, + JsonRpcResponse, + JsonRpcResponseResult, + MiddlewareFactory, +} from './types'; +import { TransformStream } from 'stream/web'; +import * as rpcErrors from './errors'; +import * as rpcUtils from './utils'; +const jsonStreamParsers = require('@streamparser/json'); + +function binaryToJsonMessageStream( + messageParser: (message: unknown) => T, + byteLimit: number = 1024 * 1024, + firstMessage?: T, +) { + const parser = new jsonStreamParsers.JSONParser({ + separator: '', + paths: ['$'], + }); + let bytesWritten: number = 0; + + return new TransformStream({ + start: (controller) => { + if (firstMessage != null) controller.enqueue(firstMessage); + parser.onValue = (value) => { + const jsonMessage = messageParser(value.value); + controller.enqueue(jsonMessage); + bytesWritten = 0; + }; + }, + transform: (chunk) => { + try { + bytesWritten += chunk.byteLength; + parser.write(chunk); + } catch (e) { + throw new rpcErrors.ErrorRpcParse(undefined, { cause: e }); + } + if (bytesWritten > byteLimit) { + throw new rpcErrors.ErrorRpcMessageLength(); + } + }, + }); +} + +function jsonMessageToBinaryStream() { + return new TransformStream({ + transform: (chunk, controller) => { + controller.enqueue(Buffer.from(JSON.stringify(chunk))); + }, + }); +} + +const defaultMiddleware: MiddlewareFactory< + JsonRpcRequest, + JsonRpcRequest, + JsonRpcResponse, + JsonRpcResponse +> = () => { + return { + forward: new TransformStream(), + reverse: new TransformStream(), + }; +}; + +const defaultServerMiddlewareWrapper = ( + middleware: MiddlewareFactory< + JsonRpcRequest, + JsonRpcRequest, + JsonRpcResponse, + JsonRpcResponse + > = defaultMiddleware, +) => { + return (header: JsonRpcRequest) => { + const inputTransformStream = binaryToJsonMessageStream( + rpcUtils.parseJsonRpcRequest, + undefined, + header, + ); + const outputTransformStream = new TransformStream< + JsonRpcResponseResult, + JsonRpcResponseResult + >(); + + const middleMiddleware = middleware(header); + + const forwardReadable = inputTransformStream.readable.pipeThrough( + middleMiddleware.forward, + ); // Usual middleware here + const reverseReadable = outputTransformStream.readable + .pipeThrough(middleMiddleware.reverse) // Usual middleware here + .pipeThrough(jsonMessageToBinaryStream()); + + return { + forward: { + readable: forwardReadable, + writable: inputTransformStream.writable, + }, + reverse: { + readable: reverseReadable, + writable: outputTransformStream.writable, + }, + }; + }; +}; + +const defaultClientMiddlewareWrapper = ( + middleware: MiddlewareFactory< + JsonRpcRequest, + JsonRpcRequest, + JsonRpcResponse, + JsonRpcResponse + > = defaultMiddleware, +): MiddlewareFactory< + Uint8Array, + JsonRpcRequest, + JsonRpcResponse, + Uint8Array +> => { + return () => { + const outputTransformStream = binaryToJsonMessageStream( + rpcUtils.parseJsonRpcResponse, + undefined, + ); + const inputTransformStream = new TransformStream< + JsonRpcRequest, + JsonRpcRequest + >(); + + const middleMiddleware = middleware(); + const forwardReadable = inputTransformStream.readable + .pipeThrough(middleMiddleware.forward) // Usual middleware here + .pipeThrough(jsonMessageToBinaryStream()); + const reverseReadable = outputTransformStream.readable.pipeThrough( + middleMiddleware.reverse, + ); // Usual middleware here + + return { + forward: { + readable: forwardReadable, + writable: inputTransformStream.writable, + }, + reverse: { + readable: reverseReadable, + writable: outputTransformStream.writable, + }, + }; + }; +}; + +export { + binaryToJsonMessageStream, + jsonMessageToBinaryStream, + defaultMiddleware, + defaultServerMiddlewareWrapper, + defaultClientMiddlewareWrapper, +}; diff --git a/src/RPC/utils.ts b/src/RPC/utils.ts index 4a069e3f1..1d5da7f6f 100644 --- a/src/RPC/utils.ts +++ b/src/RPC/utils.ts @@ -1,68 +1,25 @@ import type { + ClientManifest, + HandlerType, JsonRpcError, JsonRpcMessage, - JsonRpcRequestNotification, + JsonRpcRequest, JsonRpcRequestMessage, + JsonRpcRequestNotification, + JsonRpcResponse, JsonRpcResponseError, JsonRpcResponseResult, - JsonRpcRequest, - JsonRpcResponse, - MiddlewareFactory, - HandlerType, - ClientManifest, } from 'RPC/types'; import type { JSONValue } from '../types'; import { TransformStream } from 'stream/web'; import { AbstractError } from '@matrixai/errors'; import * as rpcErrors from './errors'; import * as utils from '../utils'; +import { promise } from '../utils'; import * as validationErrors from '../validation/errors'; import * as errors from '../errors'; -import { promise } from '../utils'; const jsonStreamParsers = require('@streamparser/json'); -function binaryToJsonMessageStream( - messageParser: (message: unknown) => T, - byteLimit: number = 1024 * 1024, - firstMessage?: T, -) { - const parser = new jsonStreamParsers.JSONParser({ - separator: '', - paths: ['$'], - }); - let bytesWritten: number = 0; - - return new TransformStream({ - start: (controller) => { - if (firstMessage != null) controller.enqueue(firstMessage); - parser.onValue = (value) => { - const jsonMessage = messageParser(value.value); - controller.enqueue(jsonMessage); - bytesWritten = 0; - }; - }, - transform: (chunk) => { - try { - bytesWritten += chunk.byteLength; - parser.write(chunk); - } catch (e) { - throw new rpcErrors.ErrorRpcParse(undefined, { cause: e }); - } - if (bytesWritten > byteLimit) { - throw new rpcErrors.ErrorRpcMessageLength(); - } - }, - }); -} - -function jsonMessageToBinaryStream() { - return new TransformStream({ - transform: (chunk, controller) => { - controller.enqueue(Buffer.from(JSON.stringify(chunk))); - }, - }); -} - function parseJsonRpcRequest( message: unknown, ): JsonRpcRequest { @@ -382,12 +339,11 @@ function reviver(key: string, value: any): any { // Root key will be '' // Reaching here means the root JSON value is not a valid exception // Therefore ErrorPolykeyUnknown is only ever returned at the top-level - const error = new errors.ErrorPolykeyUnknown('Unknown error JSON', { + return new errors.ErrorPolykeyUnknown('Unknown error JSON', { data: { json: value, }, }); - return error; } else { return value; } @@ -427,8 +383,7 @@ function clientOutputTransformStream() { } function isReturnableError(e: Error): boolean { - if (e instanceof rpcErrors.ErrorRpcNoMessageError) return false; - return true; + return !(e instanceof rpcErrors.ErrorRpcNoMessageError); } class RPCErrorEvent extends Event { @@ -448,61 +403,6 @@ class RPCErrorEvent extends Event { } } -const controllerTransformationFactory = () => { - const controllerProm = promise>(); - - class ControllerTransform implements Transformer { - start: TransformerStartCallback = async (controller) => { - // @ts-ignore: type mismatch oddity - controllerProm.resolveP(controller); - }; - - transform: TransformerTransformCallback = async ( - chunk, - controller, - ) => { - controller.enqueue(chunk); - }; - } - - class ControllerTransformStream extends TransformStream { - constructor() { - super(new ControllerTransform()); - } - } - return { - controllerP: controllerProm.p, - controllerTransformStream: new ControllerTransformStream(), - }; -}; - -function queueMergingTransformStream(messageQueue: Array) { - return new TransformStream({ - start: (controller) => { - while (true) { - const value = messageQueue.shift(); - if (value == null) break; - controller.enqueue(value); - } - }, - transform: (chunk, controller) => { - while (true) { - const value = messageQueue.shift(); - if (value == null) break; - controller.enqueue(value); - } - controller.enqueue(chunk); - }, - flush: (controller) => { - while (true) { - const value = messageQueue.shift(); - if (value == null) break; - controller.enqueue(value); - } - }, - }); -} - function extractFirstMessageTransform( messageParser: (message: unknown) => T, byteLimit: number = 1024 * 1024, @@ -574,106 +474,7 @@ function getHandlerTypes( return out; } -const defaultMiddleware: MiddlewareFactory< - JsonRpcRequest, - JsonRpcRequest, - JsonRpcResponse, - JsonRpcResponse -> = () => { - return { - forward: new TransformStream(), - reverse: new TransformStream(), - }; -}; - -const defaultServerMiddlewareWrapper = ( - middleware: MiddlewareFactory< - JsonRpcRequest, - JsonRpcRequest, - JsonRpcResponse, - JsonRpcResponse - > = defaultMiddleware, -) => { - return (header: JsonRpcRequest) => { - const inputTransformStream = binaryToJsonMessageStream( - parseJsonRpcRequest, - undefined, - header, - ); - const outputTransformStream = new TransformStream< - JsonRpcResponseResult, - JsonRpcResponseResult - >(); - - const middleMiddleware = middleware(header); - - const forwardReadable = inputTransformStream.readable.pipeThrough( - middleMiddleware.forward, - ); // Usual middleware here - const reverseReadable = outputTransformStream.readable - .pipeThrough(middleMiddleware.reverse) // Usual middleware here - .pipeThrough(jsonMessageToBinaryStream()); - - return { - forward: { - readable: forwardReadable, - writable: inputTransformStream.writable, - }, - reverse: { - readable: reverseReadable, - writable: outputTransformStream.writable, - }, - }; - }; -}; - -const defaultClientMiddlewareWrapper = ( - middleware: MiddlewareFactory< - JsonRpcRequest, - JsonRpcRequest, - JsonRpcResponse, - JsonRpcResponse - > = defaultMiddleware, -): MiddlewareFactory< - Uint8Array, - JsonRpcRequest, - JsonRpcResponse, - Uint8Array -> => { - return () => { - const outputTransformStream = binaryToJsonMessageStream( - parseJsonRpcResponse, - undefined, - ); - const inputTransformStream = new TransformStream< - JsonRpcRequest, - JsonRpcRequest - >(); - - const middleMiddleware = middleware(); - const forwardReadable = inputTransformStream.readable - .pipeThrough(middleMiddleware.forward) // Usual middleware here - .pipeThrough(jsonMessageToBinaryStream()); - const reverseReadable = outputTransformStream.readable.pipeThrough( - middleMiddleware.reverse, - ); // Usual middleware here - - return { - forward: { - readable: forwardReadable, - writable: inputTransformStream.writable, - }, - reverse: { - readable: reverseReadable, - writable: outputTransformStream.writable, - }, - }; - }; -}; - export { - binaryToJsonMessageStream, - jsonMessageToBinaryStream, parseJsonRpcRequest, parseJsonRpcRequestMessage, parseJsonRpcRequestNotification, @@ -687,11 +488,6 @@ export { clientOutputTransformStream, isReturnableError, RPCErrorEvent, - controllerTransformationFactory, - queueMergingTransformStream, extractFirstMessageTransform, getHandlerTypes, - defaultMiddleware, - defaultServerMiddlewareWrapper, - defaultClientMiddlewareWrapper, }; diff --git a/src/clientRPC/authenticationMiddleware.ts b/src/clientRPC/authenticationMiddleware.ts new file mode 100644 index 000000000..ddf89ac4d --- /dev/null +++ b/src/clientRPC/authenticationMiddleware.ts @@ -0,0 +1,122 @@ +import type { + JsonRpcRequest, + JsonRpcResponse, + MiddlewareFactory, +} from '../RPC/types'; +import type { RPCRequestParams, RPCResponseResult } from './types'; +import type { Session } from '../sessions'; +import type SessionManager from '../sessions/SessionManager'; +import type KeyRing from '../keys/KeyRing'; +import { TransformStream } from 'stream/web'; +import { authenticate, decodeAuth } from './utils'; +import * as utils from '../utils'; + +function authenticationMiddlewareServer( + sessionManager: SessionManager, + keyRing: KeyRing, +): MiddlewareFactory< + JsonRpcRequest, + JsonRpcRequest, + JsonRpcResponse, + JsonRpcResponse +> { + return () => { + let forwardFirst = true; + let reverseController; + let outgoingToken: string | null = null; + return { + forward: new TransformStream< + JsonRpcRequest, + JsonRpcRequest + >({ + transform: async (chunk, controller) => { + if (forwardFirst) { + try { + outgoingToken = await authenticate( + sessionManager, + keyRing, + chunk, + ); + } catch (e) { + controller.terminate(); + reverseController.terminate(); + return; + } + } + forwardFirst = false; + controller.enqueue(chunk); + }, + }), + reverse: new TransformStream({ + start: (controller) => { + reverseController = controller; + }, + transform: (chunk, controller) => { + // Add the outgoing metadata to the next message. + if (outgoingToken != null && 'result' in chunk) { + if (chunk.result.metadata == null) { + chunk.result.metadata = { + Authorization: '', + }; + } + chunk.result.metadata.Authorization = outgoingToken; + outgoingToken = null; + } + controller.enqueue(chunk); + }, + }), + }; + }; +} + +function authenticationMiddlewareClient( + session: Session, +): MiddlewareFactory< + JsonRpcRequest, + JsonRpcRequest, + JsonRpcResponse, + JsonRpcResponse +> { + return () => { + let forwardFirst = true; + return { + forward: new TransformStream< + JsonRpcRequest, + JsonRpcRequest + >({ + transform: async (chunk, controller) => { + if (forwardFirst) { + if (chunk.params == null) utils.never(); + if (chunk.params.metadata?.Authorization == null) { + const token = await session.readToken(); + if (token != null) { + if (chunk.params.metadata == null) { + chunk.params.metadata = { + Authorization: '', + }; + } + chunk.params.metadata.Authorization = `Bearer ${token}`; + } + } + } + forwardFirst = false; + controller.enqueue(chunk); + }, + }), + reverse: new TransformStream< + JsonRpcResponse, + JsonRpcResponse + >({ + transform: async (chunk, controller) => { + controller.enqueue(chunk); + if (!('result' in chunk)) return; + const token = decodeAuth(chunk.result); + if (token == null) return; + await session.writeToken(token); + }, + }), + }; + }; +} + +export { authenticationMiddlewareServer, authenticationMiddlewareClient }; diff --git a/src/clientRPC/utils.ts b/src/clientRPC/utils.ts index 2aaf8df11..9b280d77e 100644 --- a/src/clientRPC/utils.ts +++ b/src/clientRPC/utils.ts @@ -1,13 +1,8 @@ import type { SessionToken } from '../sessions/types'; import type KeyRing from '../keys/KeyRing'; import type SessionManager from '../sessions/SessionManager'; -import type { Session } from '../sessions'; -import type { RPCResponseResult, RPCRequestParams } from './types'; -import type { - JsonRpcRequest, - JsonRpcResponse, - MiddlewareFactory, -} from '../RPC/types'; +import type { RPCRequestParams } from './types'; +import type { JsonRpcRequest } from '../RPC/types'; import type { ReadableWritablePair } from 'stream/web'; import type Logger from '@matrixai/logger'; import type { ConnectionInfo, Host, Port } from '../network/types'; @@ -16,10 +11,9 @@ import type { TLSSocket } from 'tls'; import type { Server } from 'https'; import type net from 'net'; import type https from 'https'; -import { ReadableStream, TransformStream, WritableStream } from 'stream/web'; +import { ReadableStream, WritableStream } from 'stream/web'; import WebSocket, { WebSocketServer } from 'ws'; import * as clientErrors from '../client/errors'; -import * as utils from '../utils'; import { promise } from '../utils'; async function authenticate( @@ -71,114 +65,6 @@ function encodeAuthFromPassword(password: string): string { return `Basic ${encoded}`; } -function authenticationMiddlewareServer( - sessionManager: SessionManager, - keyRing: KeyRing, -): MiddlewareFactory< - JsonRpcRequest, - JsonRpcRequest, - JsonRpcResponse, - JsonRpcResponse -> { - return () => { - let forwardFirst = true; - let reverseController; - let outgoingToken: string | null = null; - return { - forward: new TransformStream< - JsonRpcRequest, - JsonRpcRequest - >({ - transform: async (chunk, controller) => { - if (forwardFirst) { - try { - outgoingToken = await authenticate( - sessionManager, - keyRing, - chunk, - ); - } catch (e) { - controller.terminate(); - reverseController.terminate(); - return; - } - } - forwardFirst = false; - controller.enqueue(chunk); - }, - }), - reverse: new TransformStream({ - start: (controller) => { - reverseController = controller; - }, - transform: (chunk, controller) => { - // Add the outgoing metadata to the next message. - if (outgoingToken != null && 'result' in chunk) { - if (chunk.result.metadata == null) { - chunk.result.metadata = { - Authorization: '', - }; - } - chunk.result.metadata.Authorization = outgoingToken; - outgoingToken = null; - } - controller.enqueue(chunk); - }, - }), - }; - }; -} - -function authenticationMiddlewareClient( - session: Session, -): MiddlewareFactory< - JsonRpcRequest, - JsonRpcRequest, - JsonRpcResponse, - JsonRpcResponse -> { - return () => { - let forwardFirst = true; - return { - forward: new TransformStream< - JsonRpcRequest, - JsonRpcRequest - >({ - transform: async (chunk, controller) => { - if (forwardFirst) { - if (chunk.params == null) utils.never(); - if (chunk.params.metadata?.Authorization == null) { - const token = await session.readToken(); - if (token != null) { - if (chunk.params.metadata == null) { - chunk.params.metadata = { - Authorization: '', - }; - } - chunk.params.metadata.Authorization = `Bearer ${token}`; - } - } - } - forwardFirst = false; - controller.enqueue(chunk); - }, - }), - reverse: new TransformStream< - JsonRpcResponse, - JsonRpcResponse - >({ - transform: async (chunk, controller) => { - controller.enqueue(chunk); - if (!('result' in chunk)) return; - const token = decodeAuth(chunk.result); - if (token == null) return; - await session.writeToken(token); - }, - }), - }; - }; -} - function readableFromWebSocket( ws: WebSocket, logger: Logger, @@ -372,8 +258,6 @@ export { authenticate, decodeAuth, encodeAuthFromPassword, - authenticationMiddlewareServer, - authenticationMiddlewareClient, startConnection, handleConnection, createClientServer, diff --git a/tests/RPC/RPCClient.test.ts b/tests/RPC/RPCClient.test.ts index 90c51335e..13edb94f2 100644 --- a/tests/RPC/RPCClient.test.ts +++ b/tests/RPC/RPCClient.test.ts @@ -18,7 +18,7 @@ import { ServerCaller, UnaryCaller, } from '@/RPC/callers'; -import * as rpcUtils from '@/RPC/utils'; +import * as middlewareUtils from '@/RPC/middleware'; import * as rpcTestUtils from './utils'; describe(`${RPCClient.name}`, () => { @@ -383,7 +383,7 @@ describe(`${RPCClient.name}`, () => { const rpcClient = await RPCClient.createRPCClient({ manifest: {}, streamPairCreateCallback: async () => streamPair, - middleware: rpcUtils.defaultClientMiddlewareWrapper(() => { + middleware: middlewareUtils.defaultClientMiddlewareWrapper(() => { return { forward: new TransformStream({ transform: (chunk, controller) => { @@ -447,7 +447,7 @@ describe(`${RPCClient.name}`, () => { const rpcClient = await RPCClient.createRPCClient({ manifest: {}, streamPairCreateCallback: async () => streamPair, - middleware: rpcUtils.defaultClientMiddlewareWrapper(() => { + middleware: middlewareUtils.defaultClientMiddlewareWrapper(() => { return { forward: new TransformStream(), reverse: new TransformStream({ diff --git a/tests/RPC/RPCServer.test.ts b/tests/RPC/RPCServer.test.ts index 504d8b8db..8ad825725 100644 --- a/tests/RPC/RPCServer.test.ts +++ b/tests/RPC/RPCServer.test.ts @@ -22,6 +22,7 @@ import { ServerHandler, UnaryHandler, } from '@/RPC/handlers'; +import * as middlewareUtils from '@/RPC/middleware'; import * as rpcTestUtils from './utils'; describe(`${RPCServer.name}`, () => { @@ -497,7 +498,7 @@ describe(`${RPCServer.name}`, () => { yield* input; } } - const middleware = rpcUtils.defaultServerMiddlewareWrapper(() => { + const middleware = middlewareUtils.defaultServerMiddlewareWrapper(() => { return { forward: new TransformStream({ transform: (chunk, controller) => { @@ -542,7 +543,7 @@ describe(`${RPCServer.name}`, () => { yield* input; } } - const middleware = rpcUtils.defaultServerMiddlewareWrapper(() => { + const middleware = middlewareUtils.defaultServerMiddlewareWrapper(() => { return { forward: new TransformStream(), reverse: new TransformStream({ @@ -590,7 +591,7 @@ describe(`${RPCServer.name}`, () => { yield* input; } } - const middleware = rpcUtils.defaultServerMiddlewareWrapper(() => { + const middleware = middlewareUtils.defaultServerMiddlewareWrapper(() => { let first = true; let reverseController: TransformStreamDefaultController; return { diff --git a/tests/RPC/middleware.test.ts b/tests/RPC/middleware.test.ts new file mode 100644 index 000000000..754725d71 --- /dev/null +++ b/tests/RPC/middleware.test.ts @@ -0,0 +1,96 @@ +import { fc, testProp } from '@fast-check/jest'; +import { AsyncIterableX as AsyncIterable } from 'ix/asynciterable'; +import * as rpcUtils from '@/RPC/utils'; +import 'ix/add/asynciterable-operators/toarray'; +import * as rpcErrors from '@/RPC/errors'; +import * as middleware from '@/RPC/middleware'; +import * as rpcTestUtils from './utils'; + +describe('Middleware tests', () => { + const noiseArb = fc + .array( + fc.uint8Array({ minLength: 5 }).map((array) => Buffer.from(array)), + { minLength: 5 }, + ) + .noShrink(); + + testProp( + 'can parse json stream', + [rpcTestUtils.jsonMessagesArb], + async (messages) => { + const parsedStream = rpcTestUtils + .messagesToReadableStream(messages) + .pipeThrough( + middleware.binaryToJsonMessageStream(rpcUtils.parseJsonRpcMessage), + ); // Converting back. + + const asd = await AsyncIterable.as(parsedStream).toArray(); + expect(asd).toEqual(messages); + }, + { numRuns: 1000 }, + ); + testProp( + 'Message size limit is enforced when parsing', + [ + fc.array( + rpcTestUtils.jsonRpcRequestMessageArb(fc.string({ minLength: 100 })), + { + minLength: 1, + }, + ), + ], + async (messages) => { + const parsedStream = rpcTestUtils + .messagesToReadableStream(messages) + .pipeThrough(rpcTestUtils.binaryStreamToSnippedStream([10])) + .pipeThrough( + middleware.binaryToJsonMessageStream( + rpcUtils.parseJsonRpcMessage, + 50, + ), + ); + + const doThing = async () => { + for await (const _ of parsedStream) { + // No touch, only consume + } + }; + await expect(doThing()).rejects.toThrow(rpcErrors.ErrorRpcMessageLength); + }, + { numRuns: 1000 }, + ); + testProp( + 'can parse json stream with random chunk sizes', + [rpcTestUtils.jsonMessagesArb, rpcTestUtils.snippingPatternArb], + async (messages, snippattern) => { + const parsedStream = rpcTestUtils + .messagesToReadableStream(messages) + .pipeThrough(rpcTestUtils.binaryStreamToSnippedStream(snippattern)) // Imaginary internet here + .pipeThrough( + middleware.binaryToJsonMessageStream(rpcUtils.parseJsonRpcMessage), + ); // Converting back. + + const asd = await AsyncIterable.as(parsedStream).toArray(); + expect(asd).toStrictEqual(messages); + }, + { numRuns: 1000 }, + ); + testProp( + 'Will error on bad data', + [rpcTestUtils.jsonMessagesArb, rpcTestUtils.snippingPatternArb, noiseArb], + async (messages, snippattern, noise) => { + const parsedStream = rpcTestUtils + .messagesToReadableStream(messages) + .pipeThrough(rpcTestUtils.binaryStreamToSnippedStream(snippattern)) // Imaginary internet here + .pipeThrough(rpcTestUtils.binaryStreamToNoisyStream(noise)) // Adding bad data to the stream + .pipeThrough( + middleware.binaryToJsonMessageStream(rpcUtils.parseJsonRpcMessage), + ); // Converting back. + + await expect(AsyncIterable.as(parsedStream).toArray()).rejects.toThrow( + rpcErrors.ErrorRpcParse, + ); + }, + { numRuns: 1000 }, + ); +}); diff --git a/tests/RPC/utils.test.ts b/tests/RPC/utils.test.ts index b0414dcee..75b09415e 100644 --- a/tests/RPC/utils.test.ts +++ b/tests/RPC/utils.test.ts @@ -1,70 +1,11 @@ -import { testProp, fc } from '@fast-check/jest'; +import { testProp } from '@fast-check/jest'; import { AsyncIterableX as AsyncIterable } from 'ix/asynciterable'; import * as rpcUtils from '@/RPC/utils'; import 'ix/add/asynciterable-operators/toarray'; -import * as rpcErrors from '@/RPC/errors'; +import * as middleware from '@/RPC/middleware'; import * as rpcTestUtils from './utils'; describe('utils tests', () => { - testProp( - 'can parse json stream', - [rpcTestUtils.jsonMessagesArb], - async (messages) => { - const parsedStream = rpcTestUtils - .messagesToReadableStream(messages) - .pipeThrough( - rpcUtils.binaryToJsonMessageStream(rpcUtils.parseJsonRpcMessage), - ); // Converting back. - - const asd = await AsyncIterable.as(parsedStream).toArray(); - expect(asd).toEqual(messages); - }, - { numRuns: 1000 }, - ); - - testProp( - 'can parse json stream with random chunk sizes', - [rpcTestUtils.jsonMessagesArb, rpcTestUtils.snippingPatternArb], - async (messages, snippattern) => { - const parsedStream = rpcTestUtils - .messagesToReadableStream(messages) - .pipeThrough(rpcTestUtils.binaryStreamToSnippedStream(snippattern)) // Imaginary internet here - .pipeThrough( - rpcUtils.binaryToJsonMessageStream(rpcUtils.parseJsonRpcMessage), - ); // Converting back. - - const asd = await AsyncIterable.as(parsedStream).toArray(); - expect(asd).toStrictEqual(messages); - }, - { numRuns: 1000 }, - ); - - const noiseArb = fc - .array( - fc.uint8Array({ minLength: 5 }).map((array) => Buffer.from(array)), - { minLength: 5 }, - ) - .noShrink(); - - testProp( - 'Will error on bad data', - [rpcTestUtils.jsonMessagesArb, rpcTestUtils.snippingPatternArb, noiseArb], - async (messages, snippattern, noise) => { - const parsedStream = rpcTestUtils - .messagesToReadableStream(messages) - .pipeThrough(rpcTestUtils.binaryStreamToSnippedStream(snippattern)) // Imaginary internet here - .pipeThrough(rpcTestUtils.binaryStreamToNoisyStream(noise)) // Adding bad data to the stream - .pipeThrough( - rpcUtils.binaryToJsonMessageStream(rpcUtils.parseJsonRpcMessage), - ); // Converting back. - - await expect(AsyncIterable.as(parsedStream).toArray()).rejects.toThrow( - rpcErrors.ErrorRpcParse, - ); - }, - { numRuns: 1000 }, - ); - testProp( 'can parse messages', [rpcTestUtils.jsonRpcMessageArb()], @@ -74,65 +15,6 @@ describe('utils tests', () => { { numRuns: 1000 }, ); - testProp( - 'Message size limit is enforced', - [ - fc.array( - rpcTestUtils.jsonRpcRequestMessageArb(fc.string({ minLength: 100 })), - { - minLength: 1, - }, - ), - ], - async (messages) => { - const parsedStream = rpcTestUtils - .messagesToReadableStream(messages) - .pipeThrough(rpcTestUtils.binaryStreamToSnippedStream([10])) - .pipeThrough( - rpcUtils.binaryToJsonMessageStream(rpcUtils.parseJsonRpcMessage, 50), - ); - - const doThing = async () => { - for await (const _ of parsedStream) { - // No touch, only consume - } - }; - await expect(doThing()).rejects.toThrow(rpcErrors.ErrorRpcMessageLength); - }, - { numRuns: 1000 }, - ); - - testProp( - 'merging transformation stream', - [fc.array(fc.integer()), fc.array(fc.integer())], - async (set1, set2) => { - const [outputResult, outputWriterStream] = - rpcTestUtils.streamToArray(); - const { controllerP, controllerTransformStream } = - rpcUtils.controllerTransformationFactory(); - void controllerTransformStream.readable - .pipeTo(outputWriterStream) - .catch(() => {}); - const writer = controllerTransformStream.writable.getWriter(); - const controller = await controllerP; - const expectedResult: Array = []; - for (let i = 0; i < Math.max(set1.length, set2.length); i++) { - if (set1[i] != null) { - await writer.write(set1[i]); - expectedResult.push(set1[i]); - } - if (set2[i] != null) { - controller.enqueue(set2[i]); - expectedResult.push(set2[i]); - } - } - await writer.close(); - - expect(await outputResult).toStrictEqual(expectedResult); - }, - { numRuns: 1000 }, - ); - testProp( 'can get the head message', [rpcTestUtils.jsonMessagesArb], @@ -144,7 +26,7 @@ describe('utils tests', () => { .pipeThrough(rpcTestUtils.binaryStreamToSnippedStream([7])) .pipeThrough(headTransformStream) .pipeThrough( - rpcUtils.binaryToJsonMessageStream(rpcUtils.parseJsonRpcMessage), + middleware.binaryToJsonMessageStream(rpcUtils.parseJsonRpcMessage), ); // Converting back. expect(await firstMessageProm).toStrictEqual(messages[0]); diff --git a/tests/clientRPC/authenticationMiddleware.test.ts b/tests/clientRPC/authenticationMiddleware.test.ts new file mode 100644 index 000000000..3e5d778c2 --- /dev/null +++ b/tests/clientRPC/authenticationMiddleware.test.ts @@ -0,0 +1,154 @@ +import type { Server } from 'https'; +import type { WebSocketServer } from 'ws'; +import type { RPCRequestParams, RPCResponseResult } from '@/clientRPC/types'; +import fs from 'fs'; +import path from 'path'; +import os from 'os'; +import { createServer } from 'https'; +import Logger, { LogLevel, StreamHandler } from '@matrixai/logger'; +import { DB } from '@matrixai/db'; +import KeyRing from '@/keys/KeyRing'; +import * as keysUtils from '@/keys/utils'; +import RPCServer from '@/RPC/RPCServer'; +import TaskManager from '@/tasks/TaskManager'; +import CertManager from '@/keys/CertManager'; +import RPCClient from '@/RPC/RPCClient'; +import { Session, SessionManager } from '@/sessions'; +import * as clientRPCUtils from '@/clientRPC/utils'; +import * as authMiddleware from '@/clientRPC/authenticationMiddleware'; +import { UnaryCaller } from '@/RPC/callers'; +import { UnaryHandler } from '@/RPC/handlers'; +import * as middlewareUtils from '@/RPC/middleware'; +import * as testsUtils from '../utils'; + +describe('agentUnlock', () => { + const logger = new Logger('agentUnlock test', LogLevel.WARN, [ + new StreamHandler(), + ]); + const password = 'helloworld'; + let dataDir: string; + let db: DB; + let keyRing: KeyRing; + let taskManager: TaskManager; + let certManager: CertManager; + let session: Session; + let sessionManager: SessionManager; + let server: Server; + let wss: WebSocketServer; + let port: number; + + beforeEach(async () => { + dataDir = await fs.promises.mkdtemp( + path.join(os.tmpdir(), 'polykey-test-'), + ); + const keysPath = path.join(dataDir, 'keys'); + const dbPath = path.join(dataDir, 'db'); + const sessionPath = path.join(dataDir, 'session'); + db = await DB.createDB({ + dbPath, + logger, + }); + keyRing = await KeyRing.createKeyRing({ + password, + keysPath, + logger, + passwordOpsLimit: keysUtils.passwordOpsLimits.min, + passwordMemLimit: keysUtils.passwordMemLimits.min, + strictMemoryLock: false, + }); + taskManager = await TaskManager.createTaskManager({ db, logger }); + certManager = await CertManager.createCertManager({ + db, + keyRing, + taskManager, + logger, + }); + session = await Session.createSession({ + sessionTokenPath: sessionPath, + logger, + }); + sessionManager = await SessionManager.createSessionManager({ + db, + keyRing, + logger, + }); + const tlsConfig = await testsUtils.createTLSConfig(keyRing.keyPair); + server = createServer({ + cert: tlsConfig.certChainPem, + key: tlsConfig.keyPrivatePem, + }); + port = await clientRPCUtils.listen(server, '127.0.0.1'); + }); + afterEach(async () => { + wss?.close(); + server.close(); + await certManager.stop(); + await taskManager.stop(); + await keyRing.stop(); + await db.stop(); + await fs.promises.rm(dataDir, { + force: true, + recursive: true, + }); + }); + test('get status', async () => { + // Setup + class EchoHandler extends UnaryHandler< + { logger: Logger }, + RPCRequestParams, + RPCResponseResult + > { + public async handle(input: RPCRequestParams): Promise { + return input; + } + } + const rpcServer = await RPCServer.createRPCServer({ + manifest: { + agentUnlock: new EchoHandler({ logger }), + }, + middleware: middlewareUtils.defaultServerMiddlewareWrapper( + authMiddleware.authenticationMiddlewareServer(sessionManager, keyRing), + ), + logger, + }); + wss = clientRPCUtils.createClientServer( + server, + rpcServer, + logger.getChild('server'), + ); + const rpcClient = await RPCClient.createRPCClient({ + manifest: { + agentUnlock: new UnaryCaller(), + }, + streamPairCreateCallback: async () => { + return clientRPCUtils.startConnection( + '127.0.0.1', + port, + logger.getChild('client'), + ); + }, + middleware: middlewareUtils.defaultClientMiddlewareWrapper( + authMiddleware.authenticationMiddlewareClient(session), + ), + logger, + }); + + // Doing the test + const result = await rpcClient.methods.agentUnlock({ + metadata: { + Authorization: clientRPCUtils.encodeAuthFromPassword(password), + }, + }); + expect(result).toMatchObject({ + metadata: { + Authorization: expect.any(String), + }, + }); + const result2 = await rpcClient.methods.agentUnlock({}); + expect(result2).toMatchObject({ + metadata: { + Authorization: expect.any(String), + }, + }); + }); +}); diff --git a/tests/clientRPC/handlers/agentUnlock.test.ts b/tests/clientRPC/handlers/agentUnlock.test.ts index 45939e327..1b592af3b 100644 --- a/tests/clientRPC/handlers/agentUnlock.test.ts +++ b/tests/clientRPC/handlers/agentUnlock.test.ts @@ -18,7 +18,8 @@ import { import RPCClient from '@/RPC/RPCClient'; import { Session, SessionManager } from '@/sessions'; import * as clientRPCUtils from '@/clientRPC/utils'; -import * as rpcUtils from '@/RPC/utils'; +import * as authMiddleware from '@/clientRPC/authenticationMiddleware'; +import * as middlewareUtils from '@/RPC/middleware'; import * as testsUtils from '../../utils'; describe('agentUnlock', () => { @@ -91,14 +92,14 @@ describe('agentUnlock', () => { recursive: true, }); }); - test('get status', async () => { + test('unlock', async () => { // Setup const rpcServer = await RPCServer.createRPCServer({ manifest: { agentUnlock: new AgentUnlockHandler({ logger }), }, - middleware: rpcUtils.defaultServerMiddlewareWrapper( - clientRPCUtils.authenticationMiddlewareServer(sessionManager, keyRing), + middleware: middlewareUtils.defaultServerMiddlewareWrapper( + authMiddleware.authenticationMiddlewareServer(sessionManager, keyRing), ), logger, }); @@ -118,8 +119,8 @@ describe('agentUnlock', () => { logger.getChild('client'), ); }, - middleware: rpcUtils.defaultClientMiddlewareWrapper( - clientRPCUtils.authenticationMiddlewareClient(session), + middleware: middlewareUtils.defaultClientMiddlewareWrapper( + authMiddleware.authenticationMiddlewareClient(session), ), logger, }); From 71833a0afe1b9c708f08f59e4e6228fc95c76312 Mon Sep 17 00:00:00 2001 From: Brian Botha Date: Tue, 14 Feb 2023 13:15:28 +1100 Subject: [PATCH 44/44] feat: updated client caller interfaces - Related #501 --- src/RPC/RPCClient.ts | 74 +++------------ src/RPC/types.ts | 24 +---- tests/RPC/RPC.test.ts | 4 +- tests/RPC/RPCClient.test.ts | 173 +++++++++++------------------------- 4 files changed, 68 insertions(+), 207 deletions(-) diff --git a/src/RPC/RPCClient.ts b/src/RPC/RPCClient.ts index f0a684a02..91f2c729d 100644 --- a/src/RPC/RPCClient.ts +++ b/src/RPC/RPCClient.ts @@ -3,10 +3,13 @@ import type { JsonRpcRequestMessage, StreamPairCreateCallback, ClientManifest, - MapRawCallers, } from './types'; import type { JSONValue } from 'types'; -import type { ReadableWritablePair, WritableStream } from 'stream/web'; +import type { + ReadableWritablePair, + WritableStream, + ReadableStream, +} from 'stream/web'; import type { JsonRpcRequest, JsonRpcResponse, @@ -78,27 +81,9 @@ class RPCClient { case 'CLIENT': return () => this.clientStreamCaller(method); case 'DUPLEX': - return (f) => this.duplexStreamCaller(method, f); + return () => this.duplexStreamCaller(method); case 'RAW': - default: - return; - } - }, - }, - ); - protected rawMethodsProxy = new Proxy( - {}, - { - get: (_, method) => { - if (typeof method === 'symbol') throw never(); - switch (this.callerTypes[method]) { - case 'DUPLEX': - return () => this.rawDuplexStreamCaller(method); - case 'RAW': - return (params) => this.rawStreamCaller(method, params); - case 'SERVER': - case 'CLIENT': - case 'UNARY': + return (header) => this.rawStreamCaller(method, header); default: return; } @@ -138,17 +123,12 @@ class RPCClient { return this.methodsProxy as MapCallers; } - @ready(new rpcErrors.ErrorRpcDestroyed()) - public get rawMethods(): MapRawCallers { - return this.rawMethodsProxy as MapRawCallers; - } - @ready(new rpcErrors.ErrorRpcDestroyed()) public async unaryCaller( method: string, parameters: I, ): Promise { - const callerInterface = await this.rawDuplexStreamCaller(method); + const callerInterface = await this.duplexStreamCaller(method); const reader = callerInterface.readable.getReader(); const writer = callerInterface.writable.getWriter(); await writer.write(parameters); @@ -165,18 +145,13 @@ class RPCClient { public async serverStreamCaller( method: string, parameters: I, - ): Promise> { - const callerInterface = await this.rawDuplexStreamCaller(method); + ): Promise> { + const callerInterface = await this.duplexStreamCaller(method); const writer = callerInterface.writable.getWriter(); await writer.write(parameters); await writer.close(); - const outputGen = async function* () { - for await (const value of callerInterface.readable) { - yield value; - } - }; - return outputGen(); + return callerInterface.readable; } @ready(new rpcErrors.ErrorRpcDestroyed()) @@ -186,7 +161,7 @@ class RPCClient { output: Promise; writable: WritableStream; }> { - const callerInterface = await this.rawDuplexStreamCaller(method); + const callerInterface = await this.duplexStreamCaller(method); const reader = callerInterface.readable.getReader(); const output = reader.read().then(({ value, done }) => { if (done) { @@ -203,27 +178,6 @@ class RPCClient { @ready(new rpcErrors.ErrorRpcDestroyed()) public async duplexStreamCaller( method: string, - f: (output: AsyncIterable) => AsyncIterable, - ): Promise { - const callerInterface = await this.rawDuplexStreamCaller(method); - const outputGenerator = async function* () { - for await (const value of callerInterface.readable) { - yield value; - } - }; - const writer = callerInterface.writable.getWriter(); - try { - for await (const value of f(outputGenerator())) { - await writer.write(value); - } - } finally { - await writer.close(); - } - } - - @ready(new rpcErrors.ErrorRpcDestroyed()) - public async rawDuplexStreamCaller( - method: string, ): Promise> { const outputMessageTransformStream = clientOutputTransformStream(); const inputMessageTransformStream = clientInputTransformStream(method); @@ -249,14 +203,14 @@ class RPCClient { @ready(new rpcErrors.ErrorRpcDestroyed()) public async rawStreamCaller( method: string, - params: JSONValue, + headerParams: JSONValue, ): Promise> { const streamPair = await this.streamPairCreateCallback(); const tempWriter = streamPair.writable.getWriter(); const header: JsonRpcRequestMessage = { jsonrpc: '2.0', method, - params, + params: headerParams, id: null, }; await tempWriter.write(Buffer.from(JSON.stringify(header))); diff --git a/src/RPC/types.ts b/src/RPC/types.ts index 4e6633527..4d96fcc0c 100644 --- a/src/RPC/types.ts +++ b/src/RPC/types.ts @@ -156,7 +156,7 @@ type UnaryCallerImplementation< type ServerCallerImplementation< I extends JSONValue = JSONValue, O extends JSONValue = JSONValue, -> = (parameters: I) => Promise>; +> = (parameters: I) => Promise>; type ClientCallerImplementation< I extends JSONValue = JSONValue, @@ -166,17 +166,10 @@ type ClientCallerImplementation< type DuplexCallerImplementation< I extends JSONValue = JSONValue, O extends JSONValue = JSONValue, -> = (f: (output: AsyncIterable) => AsyncIterable) => Promise; - -// Raw callers - -type RawDuplexCallerImplementation< - I extends JSONValue = JSONValue, - O extends JSONValue = JSONValue, > = () => Promise>; type RawCallerImplementation = ( - params: JSONValue, + headerParams: JSONValue, ) => Promise>; type ConvertDuplexCaller = T extends DuplexCaller @@ -203,14 +196,6 @@ type ConvertCaller = T extends DuplexCaller ? ConvertClientCaller : T extends UnaryCaller ? ConvertUnaryCaller - : never; - -type ConvertRawDuplexStreamHandler = T extends DuplexCaller - ? RawDuplexCallerImplementation - : never; - -type ConvertRawCaller = T extends DuplexCaller - ? ConvertRawDuplexStreamHandler : T extends RawCaller ? RawCallerImplementation : never; @@ -224,10 +209,6 @@ type MapCallers = { [K in keyof T]: ConvertCaller; }; -type MapRawCallers = { - [K in keyof T]: ConvertRawCaller; -}; - export type { JsonRpcRequestMessage, JsonRpcRequestNotification, @@ -250,5 +231,4 @@ export type { ClientManifest, HandlerType, MapCallers, - MapRawCallers, }; diff --git a/tests/RPC/RPC.test.ts b/tests/RPC/RPC.test.ts index 28c9704e2..883b4e4f1 100644 --- a/tests/RPC/RPC.test.ts +++ b/tests/RPC/RPC.test.ts @@ -63,7 +63,7 @@ describe('RPC', () => { logger, }); - const callerInterface = await rpcClient.rawMethods.testMethod({ + const callerInterface = await rpcClient.methods.testMethod({ hello: 'world', }); const writer = callerInterface.writable.getWriter(); @@ -116,7 +116,7 @@ describe('RPC', () => { logger, }); - const callerInterface = await rpcClient.rawMethods.testMethod(); + const callerInterface = await rpcClient.methods.testMethod(); const writer = callerInterface.writable.getWriter(); const reader = callerInterface.readable.getReader(); for (const value of values) { diff --git a/tests/RPC/RPCClient.test.ts b/tests/RPC/RPCClient.test.ts index 13edb94f2..672f10626 100644 --- a/tests/RPC/RPCClient.test.ts +++ b/tests/RPC/RPCClient.test.ts @@ -98,12 +98,15 @@ describe(`${RPCClient.name}`, () => { streamPairCreateCallback: async () => streamPair, logger, }); - await rpcClient.duplexStreamCaller( - methodName, - async function* (output) { - yield* output; - }, - ); + const callerInterface = await rpcClient.duplexStreamCaller< + JSONValue, + JSONValue + >(methodName); + const writable = callerInterface.writable.getWriter(); + for await (const value of callerInterface.readable) { + await writable.write(value); + } + await writable.close(); const expectedMessages: Array = messages.map((v) => { const request: JsonRpcRequestMessage = { @@ -252,14 +255,16 @@ describe(`${RPCClient.name}`, () => { streamPairCreateCallback: async () => streamPair, logger, }); - const callProm = rpcClient.duplexStreamCaller( - methodName, - async function* (output) { - for await (const _ of output) { - // No touch, just consume - } - }, - ); + const callerInterface = await rpcClient.duplexStreamCaller< + JSONValue, + JSONValue + >(methodName); + await callerInterface.writable.close(); + const callProm = (async () => { + for await (const _ of callerInterface.readable) { + // Only consume + } + })(); await expect(callProm).rejects.toThrow(rpcErrors.ErrorRpcRemoteError); await outputResult; await rpcClient.destroy(); @@ -287,14 +292,16 @@ describe(`${RPCClient.name}`, () => { streamPairCreateCallback: async () => streamPair, logger, }); - const callProm = rpcClient.duplexStreamCaller( - methodName, - async function* (output) { - for await (const _ of output) { - // No touch, just consume - } - }, - ); + const callerInterface = await rpcClient.duplexStreamCaller< + JSONValue, + JSONValue + >(methodName); + await callerInterface.writable.close(); + const callProm = (async () => { + for await (const _ of callerInterface.readable) { + // Only consume + } + })(); await expect(callProm).rejects.toThrow(rpcErrors.ErrorRpcRemoteError); await outputResult; await rpcClient.destroy(); @@ -325,50 +332,21 @@ describe(`${RPCClient.name}`, () => { streamPairCreateCallback: async () => streamPair, logger, }); - const callProm = rpcClient.duplexStreamCaller( - methodName, - async function* (output) { - for await (const _ of output) { - // No touch, just consume - } - }, - ); + const callerInterface = await rpcClient.duplexStreamCaller< + JSONValue, + JSONValue + >(methodName); + await callerInterface.writable.close(); + const callProm = (async () => { + for await (const _ of callerInterface.readable) { + // Only consume + } + })(); await expect(callProm).rejects.toThrow(rpcErrors.ErrorRpcRemoteError); await outputResult; await rpcClient.destroy(); }, ); - testProp( - 'rawDuplexStreamCaller', - [fc.array(rpcTestUtils.jsonRpcResponseResultArb(), { minLength: 1 })], - async (messages) => { - const inputStream = rpcTestUtils.messagesToReadableStream(messages); - const [outputResult, outputStream] = - rpcTestUtils.streamToArray(); - const streamPair: ReadableWritablePair = { - readable: inputStream, - writable: outputStream, - }; - const rpcClient = await RPCClient.createRPCClient({ - manifest: {}, - streamPairCreateCallback: async () => streamPair, - logger, - }); - let count = 0; - const callerInterface = await rpcClient.rawDuplexStreamCaller(methodName); - const writer = callerInterface.writable.getWriter(); - for await (const val of callerInterface.readable) { - count += 1; - await writer.write(val); - } - await writer.close(); - const result = await outputResult; - // We're just checking that it's consuming the messages as expected - expect(result.length).toEqual(messages.length); - expect(count).toEqual(messages.length); - await rpcClient.destroy(); - }, - ); testProp( 'generic duplex caller with forward Middleware', [specificMessageArb], @@ -399,7 +377,7 @@ describe(`${RPCClient.name}`, () => { logger, }); - const callerInterface = await rpcClient.rawDuplexStreamCaller< + const callerInterface = await rpcClient.duplexStreamCaller< JSONValue, JSONValue >(methodName); @@ -463,7 +441,7 @@ describe(`${RPCClient.name}`, () => { logger, }); - const callerInterface = await rpcClient.rawDuplexStreamCaller< + const callerInterface = await rpcClient.duplexStreamCaller< JSONValue, JSONValue >(methodName); @@ -483,58 +461,6 @@ describe(`${RPCClient.name}`, () => { await rpcClient.destroy(); }, ); - testProp( - 'manifest raw duplex call', - [ - fc.array(rpcTestUtils.jsonRpcResponseResultArb(fc.string()), { - minLength: 5, - }), - ], - async (messages) => { - const inputStream = rpcTestUtils.messagesToReadableStream(messages); - const [outputResult, outputStream] = rpcTestUtils.streamToArray(); - const streamPair: ReadableWritablePair = { - readable: inputStream, - writable: outputStream, - }; - const rpcClient = await RPCClient.createRPCClient({ - manifest: { - duplex: new DuplexCaller(), - }, - streamPairCreateCallback: async () => streamPair, - logger, - }); - const callerInterface = await rpcClient.rawMethods.duplex(); - const reader = callerInterface.readable.getReader(); - const writer = callerInterface.writable.getWriter(); - while (true) { - const { value, done } = await reader.read(); - if (done) { - // We have to end the writer otherwise the stream never closes - await writer.close(); - break; - } - await writer.write(value); - } - const expectedMessages: Array = messages.map( - (v) => { - const request: JsonRpcRequestMessage = { - jsonrpc: '2.0', - method: 'duplex', - id: null, - ...(v.result === undefined ? {} : { params: v.result }), - }; - return request; - }, - ); - const outputMessages = (await outputResult).map((v) => - JSON.parse(v.toString()), - ); - expect(outputMessages).toStrictEqual(expectedMessages); - - await rpcClient.destroy(); - }, - ); testProp( 'manifest server call', [specificMessageArb, fc.string()], @@ -643,7 +569,7 @@ describe(`${RPCClient.name}`, () => { }, ); testProp( - 'manifest raw duplex caller', + 'manifest raw caller', [ rpcTestUtils.safeJsonValueArb, rpcTestUtils.rawDataArb, @@ -672,7 +598,7 @@ describe(`${RPCClient.name}`, () => { streamPairCreateCallback: async () => streamPair, logger, }); - const callerInterface = await rpcClient.rawMethods.raw(headerParams); + const callerInterface = await rpcClient.methods.raw(headerParams); await callerInterface.readable.pipeTo(outputWritableStream); const writer = callerInterface.writable.getWriter(); for (const inputDatum of inputData) { @@ -716,12 +642,13 @@ describe(`${RPCClient.name}`, () => { logger, }); let count = 0; - await rpcClient.methods.duplex(async function* (output) { - for await (const value of output) { - count += 1; - yield value; - } - }); + const callerInterface = await rpcClient.methods.duplex(); + const writer = callerInterface.writable.getWriter(); + for await (const value of callerInterface.readable) { + count += 1; + await writer.write(value); + } + await writer.close(); const result = await outputResult; // We're just checking that it's consuming the messages as expected expect(result.length).toEqual(messages.length);