Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 17 additions & 0 deletions aisdk/ai/provider/openai/internal/codec/jsonschema.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,17 @@ func encodeSchema(schema *jsonschema.Schema) (map[string]any, error) {
return nil, nil
}

// Enforce OpenAI restrictions
// https://platform.openai.com/docs/guides/structured-outputs#root-objects-must-not-be-anyof-and-must-be-an-object
// NOTE: we could simply encode the input schema, pass it through to OpenAI and let it return an error, but there are
// other encoding rules we want to enforce later, and limiting the scope here allows us to limit the scope later.
if schema.Type != "object" {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think I would create a helper validateSchema method. Up to you though.

return nil, fmt.Errorf("schema root must be of type object, got: %s", schema.Type)
}
if schema.AnyOf != nil {
return nil, fmt.Errorf("schema root cannot use AnyOf")
}

// Marshal to JSON and unmarshal back to interface{} to convert the types
data, err := json.Marshal(schema)
if err != nil {
Expand All @@ -32,6 +43,12 @@ func encodeSchema(schema *jsonschema.Schema) (map[string]any, error) {
return nil, fmt.Errorf("failed to unmarshal properties: %w\n\n%s", err, data)
}

// Ensure properties field is set, even if it's empty. It's unclear whether OpenAI requires
// this to be set for nested schema objects too. For now we only set it at the top-level.
if _, ok := result["properties"]; !ok {
result["properties"] = map[string]any{}
}

// Convert {"not": {}} patterns to false throughout the schema
normalizeSchemaMap(result)

Expand Down
110 changes: 69 additions & 41 deletions aisdk/ai/provider/openai/internal/codec/jsonschema_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -143,20 +143,42 @@ func TestEncodeSchema(t *testing.T) {
}`,
},
{
name: "schema with allOf containing additionalProperties",
name: "schema with nested AnyOf",
input: &jsonschema.Schema{
AllOf: []*jsonschema.Schema{
{
Type: "object",
AdditionalProperties: api.FalseSchema(),
Type: "object",
Properties: map[string]*jsonschema.Schema{
"numeric": {
AnyOf: []*jsonschema.Schema{
{
Type: "string",
},
{
Type: "number",
},
},
},
},
},
want: `{
"allOf": [{
"type": "object",
"additionalProperties": false
}]
"type": "object",
"properties": {
"numeric": {
"anyOf": [
{ "type": "string" },
{ "type": "number" }
]
}
}
}`,
},
{
name: "schema without properties gets empty properties map",
input: &jsonschema.Schema{
Type: "object",
},
want: `{
"type": "object",
"properties": {}
}`,
},
{
Expand Down Expand Up @@ -210,6 +232,44 @@ func TestEncodeSchema(t *testing.T) {
"required": ["id"]
}`,
},

// Edge/error cases
{
name: "schema with non-object root",
input: &jsonschema.Schema{
Properties: map[string]*jsonschema.Schema{
"name": {
Type: "string",
Description: "The name",
},
},
},
wantErr: true,
},
{
name: "empty schema",
input: &jsonschema.Schema{},
wantErr: true,
},
{
name: "schema with only additional properties",
input: &jsonschema.Schema{
AdditionalProperties: api.FalseSchema(),
},
wantErr: true,
},
{
name: "schema with AnyOf at rool level",
input: &jsonschema.Schema{
AnyOf: []*jsonschema.Schema{
{
Type: "object",
AdditionalProperties: api.FalseSchema(),
},
},
},
wantErr: true,
},
}

for _, tt := range tests {
Expand Down Expand Up @@ -451,35 +511,3 @@ func TestNormalizeSchemaMap(t *testing.T) {
})
}
}

func TestEncodeSchema_EdgeCases(t *testing.T) {
t.Run("schema with only additionalProperties", func(t *testing.T) {
schema := &jsonschema.Schema{
AdditionalProperties: api.FalseSchema(),
}

got, err := encodeSchema(schema)
require.NoError(t, err)

gotJSON, err := json.Marshal(got)
require.NoError(t, err)

expectedJSON := `{
"additionalProperties": false
}`

assert.JSONEq(t, expectedJSON, string(gotJSON))
})

t.Run("empty schema", func(t *testing.T) {
schema := &jsonschema.Schema{}

got, err := encodeSchema(schema)
require.NoError(t, err)

gotJSON, err := json.Marshal(got)
require.NoError(t, err)

assert.JSONEq(t, "{}", string(gotJSON))
})
}
Loading