Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
113 changes: 61 additions & 52 deletions lokerpc/codegen/go.go
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,38 @@ func GenGoType(schema jtd.Schema, imports map[string]struct{}) string {
return t
}

type resolvedMethod struct {
reqType string
resType string
isVoid bool
}

// resolveMethodTypes determines the Go request and response types for an endpoint,
// including whether the method has a void return type.
func resolveMethodTypes(v lokerpc.EndpointMeta, imports map[string]struct{}) resolvedMethod {
reqType := "any"
if v.RequestTypeDef != nil {
reqType = GenGoType(*v.RequestTypeDef, imports)
}

resType := "any"
isVoid := false
if v.ResponseTypeDef != nil {
if v.ResponseTypeDef.Metadata["void"] == true {
isVoid = true
resType = ""
} else {
resType = GenGoType(*v.ResponseTypeDef, imports)

if !strings.HasPrefix(resType, "[]") && !strings.HasPrefix(resType, "map[") && !strings.HasPrefix(resType, "*") {
resType = "*" + resType
}
}
}

return resolvedMethod{reqType: reqType, resType: resType, isVoid: isVoid}
}

func GenGoClient(w io.Writer, meta lokerpc.Meta) error {
defOrder := normalise(&meta)

Expand All @@ -97,26 +129,14 @@ func GenGoClient(w io.Writer, meta lokerpc.Meta) error {
// goDocComment(b, meta.Help, "")
b.WriteString("type " + goFieldName(meta.ServiceName) + "Service interface {\n")
for _, v := range meta.Interfaces {
reqType := "any"
if v.RequestTypeDef != nil {
reqType = GenGoType(*v.RequestTypeDef, imports)
}

resType := "any"
if v.ResponseTypeDef != nil {
if v.ResponseTypeDef.Metadata["void"] == true {
resType = "struct{}"
} else {
resType = GenGoType(*v.ResponseTypeDef, imports)

if !strings.HasPrefix(resType, "[]") && !strings.HasPrefix(resType, "map[") && !strings.HasPrefix(resType, "*") {
resType = "*" + resType
}
}
}
m := resolveMethodTypes(v, imports)

// goDocComment(b, v.Help, "\t")
fmt.Fprintf(&b, "\t%s(context.Context, %s) (%s, error)\n", goFieldName(v.MethodName), reqType, resType)
if m.isVoid {
fmt.Fprintf(&b, "\t%s(context.Context, %s) error\n", goFieldName(v.MethodName), m.reqType)
} else {
fmt.Fprintf(&b, "\t%s(context.Context, %s) (%s, error)\n", goFieldName(v.MethodName), m.reqType, m.resType)
}
}
b.WriteString("}\n")

Expand All @@ -125,44 +145,33 @@ func GenGoClient(w io.Writer, meta lokerpc.Meta) error {
// goDocComment(b, meta.Help, "")
b.WriteString("type " + goFieldName(meta.ServiceName) + "RPCClient struct{\nlokerpc.Client}\n\n")
for _, v := range meta.Interfaces {
reqType := "any"
if v.RequestTypeDef != nil {
reqType = GenGoType(*v.RequestTypeDef, imports)
}

resType := "any"
if v.ResponseTypeDef != nil {
if v.ResponseTypeDef.Metadata["void"] == true {
resType = "struct{}"
} else {
resType = GenGoType(*v.ResponseTypeDef, imports)
m := resolveMethodTypes(v, imports)

if !strings.HasPrefix(resType, "[]") && !strings.HasPrefix(resType, "map[") && !strings.HasPrefix(resType, "*") {
resType = "*" + resType
}
if m.isVoid {
fmt.Fprintf(&b, "func (c %sRPCClient) %s(ctx context.Context, req %s) error {\n", goFieldName(meta.ServiceName), goFieldName(v.MethodName), m.reqType)
fmt.Fprintf(&b, "\treturn c.DoRequest(ctx, \"%s\", req, nil)\n", v.MethodName)
fmt.Fprintf(&b, "}\n")
} else {
varType := m.resType
if varType != "any" && strings.HasPrefix(varType, "*") {
varType = varType[1:]
}
}

varType := resType
if varType != "any" && strings.HasPrefix(varType, "*") {
varType = varType[1:]
}

// goDocComment(b, v.Help, "\t")
fmt.Fprintf(&b, "func (c %sRPCClient) %s(ctx context.Context, req %s) (%s, error) {\n", goFieldName(meta.ServiceName), goFieldName(v.MethodName), reqType, resType)
fmt.Fprintf(&b, "\tvar res %s\n", varType)
fmt.Fprintf(&b, "\terr := c.DoRequest(ctx, \"%s\", req, &res)\n", v.MethodName)
fmt.Fprintf(&b, "\tif err != nil {\n")
fmt.Fprintf(&b, "\t\treturn nil, err\n")
fmt.Fprintf(&b, "\t}\n")
if resType == "any" {
fmt.Fprintf(&b, "\treturn res, nil\n")
} else if strings.HasPrefix(resType, "*") {
fmt.Fprintf(&b, "\treturn &res, nil\n")
} else {
fmt.Fprintf(&b, "\treturn res, nil\n")
fmt.Fprintf(&b, "func (c %sRPCClient) %s(ctx context.Context, req %s) (%s, error) {\n", goFieldName(meta.ServiceName), goFieldName(v.MethodName), m.reqType, m.resType)
fmt.Fprintf(&b, "\tvar res %s\n", varType)
fmt.Fprintf(&b, "\terr := c.DoRequest(ctx, \"%s\", req, &res)\n", v.MethodName)
fmt.Fprintf(&b, "\tif err != nil {\n")
fmt.Fprintf(&b, "\t\treturn nil, err\n")
fmt.Fprintf(&b, "\t}\n")
if m.resType == "any" {
fmt.Fprintf(&b, "\treturn res, nil\n")
} else if strings.HasPrefix(m.resType, "*") {
fmt.Fprintf(&b, "\treturn &res, nil\n")
} else {
fmt.Fprintf(&b, "\treturn res, nil\n")
}
fmt.Fprintf(&b, "}\n")
}
fmt.Fprintf(&b, "}\n")
}

// Write header
Expand Down
22 changes: 20 additions & 2 deletions lokerpc/codegen/go_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,10 +49,28 @@ func TestGenGoClient(t *testing.T) {
}

formatted, err := format.Source(buf.Bytes())
if err != nil {
// Some fixtures (e.g., spaces-hyphens.json) produce fields that are
// not valid Go identifiers. This is a known codegen limitation.
t.Skipf("generated code is not valid Go: %v", err)
}

goldenPath := p + ".go"
if os.Getenv("UPDATE_GOLDEN") != "" {
err = os.WriteFile(goldenPath, formatted, 0644)
if err != nil {
t.Fatal(err)
}
return
}

err = os.WriteFile(p+".go", formatted, 0644)
expected, err := os.ReadFile(goldenPath)
if err != nil {
t.Fatal(err)
t.Fatalf("golden file %s not found; run with UPDATE_GOLDEN=1 to create it", goldenPath)
}

if !bytes.Equal(formatted, expected) {
t.Errorf("generated output differs from %s; run with UPDATE_GOLDEN=1 to update", goldenPath)
}
})
}
Expand Down
11 changes: 3 additions & 8 deletions lokerpc/codegen/testdata/void.json.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,18 +7,13 @@ import (
)

type Service1Service interface {
Hello1(context.Context, any) (struct{}, error)
Hello1(context.Context, any) error
}

type Service1RPCClient struct {
lokerpc.Client
}

func (c Service1RPCClient) Hello1(ctx context.Context, req any) (struct{}, error) {
var res struct{}
err := c.DoRequest(ctx, "hello1", req, &res)
if err != nil {
return nil, err
}
return res, nil
func (c Service1RPCClient) Hello1(ctx context.Context, req any) error {
return c.DoRequest(ctx, "hello1", req, nil)
}
42 changes: 40 additions & 2 deletions lokerpc/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@ type EndpointCodec struct {
requestType reflect.Type
responseType reflect.Type
errOnNilResponse bool
voidResponse bool
}

// EndpointCodecMap maps the Request.Method to the proper EndpointCodec
Expand Down Expand Up @@ -129,6 +130,8 @@ func DecodeRequest[Req any](_ context.Context, msg json.RawMessage) (any, error)

type StandardMethod[Req any, Res any] func(context.Context, Req) (Res, error)

type VoidMethod[Req any] func(context.Context, Req) error

func MakeStandardEndpoint[Req any, Res any](method StandardMethod[Req, Res]) Endpoint {
return func(ctx context.Context, request any) (any, error) {
req := request.(Req)
Expand All @@ -137,6 +140,14 @@ func MakeStandardEndpoint[Req any, Res any](method StandardMethod[Req, Res]) End
}
}

func MakeVoidEndpoint[Req any](method VoidMethod[Req]) Endpoint {
return func(ctx context.Context, request any) (any, error) {
req := request.(Req)
err := method(ctx, req)
return standardResponse{nil, err}, nil
}
}

type EndpointCodecOption func(*EndpointCodec)

// MakeStandardEndpointCodec
Expand All @@ -161,6 +172,29 @@ func MakeStandardEndpointCodec[Req any, Res any](method StandardMethod[Req, Res]
return ec
}

// MakeVoidEndpointCodec creates an EndpointCodec for methods that return no value.
// The generated metadata will include "void": true on the response type.
func MakeVoidEndpointCodec[Req any](method VoidMethod[Req], help string, opts ...EndpointCodecOption) EndpointCodec {
var req Req

ec := EndpointCodec{
Endpoint: MakeVoidEndpoint(method),
Decode: DecodeRequest[Req],
ParamNames: FieldNames(req),
Help: help,

requestType: reflect.TypeOf(req),
responseType: nil,
voidResponse: true,
}

for _, opt := range opts {
opt(&ec)
}

return ec
}

func NoNilResponse() EndpointCodecOption {
return func(ec *EndpointCodec) {
ec.errOnNilResponse = true
Expand Down Expand Up @@ -258,7 +292,11 @@ func MountHandlers(logger log.Logger, mux Mux, services ...*Service) {
endMeta.RequestTypeDef = TypeSchema(ec.requestType, defs)
endMeta.RequestTypeDef.Nullable = false
}
if ec.responseType != nil {
if ec.voidResponse {
endMeta.ResponseTypeDef = &jtd.Schema{
Metadata: map[string]any{"void": true},
}
} else if ec.responseType != nil {
endMeta.ResponseTypeDef = TypeSchema(ec.responseType, defs)
if ec.errOnNilResponse {
endMeta.ResponseTypeDef.Nullable = false
Expand Down Expand Up @@ -401,7 +439,7 @@ func makeHandler(logger log.Logger, ec EndpointCodec) http.HandlerFunc {
result = r.Result()
}

if result == nil && ec.errOnNilResponse {
if result == nil && ec.errOnNilResponse && !ec.voidResponse {
logErr("err", "unexpected nil response")

status = http.StatusInternalServerError
Expand Down
Loading