Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -1243,6 +1243,7 @@ golang.org/x/time v0.0.0-20181108054448-85acf8d2951c/go.mod h1:tRJNPiyCQ0inRvYxb
golang.org/x/time v0.0.0-20190308202827-9d24e82272b4/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
golang.org/x/time v0.3.0 h1:rg5rLMjNzMS1RkNLzCG38eapWhnYLFYXDXj2gOlr8j4=
golang.org/x/time v0.3.0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
golang.org/x/time v0.6.0 h1:eTDhh4ZXt5Qf0augr54TN6suAUudPcawVZeIAPU7D4U=
golang.org/x/tools v0.0.0-20180221164845-07fd8470d635/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
golang.org/x/tools v0.0.0-20180828015842-6cd1fcedba52/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
Expand Down
187 changes: 110 additions & 77 deletions models/esmfold.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,19 +10,25 @@ import (
"gobius/config"
"gobius/ipfs"
"gobius/utils"
"io"
"net/http"
"path/filepath"
"strings"
"time"

"os"
"os/exec"
"crypto/tls"

"github.com/google/uuid"
"github.com/mr-tron/base58"
"github.com/rs/zerolog"
)

type ESMFoldV1Input struct {
Sequence string `json:"sequence"`
// ESMFoldV1Inner holds the prompt string of an ESMFold request. The field is
// serialized under the JSON key "prompt".
type ESMFoldV1Inner struct {
Prompt string `json:"prompt"`
}

// ESMFoldV1Prompt is the hydrated input for the ESMFold model: an
// ESMFoldV1Inner serialized under the JSON key "input".
type ESMFoldV1Prompt struct {
Input ESMFoldV1Inner `json:"input"`
}

type ESMFoldV1Model struct {
Expand Down Expand Up @@ -51,7 +57,7 @@ var ESMFoldV1ModelTemplate = Model{
"version": 1,
"input": []map[string]any{
{
"variable": "sequence",
"variable": "prompt",
"type": "string",
"required": true,
"default": "",
Expand Down Expand Up @@ -79,40 +85,46 @@ func NewESMFoldV1Model(client ipfs.IPFSClient, appConfig *config.AppConfig, logg
return nil
}

http := &http.Client{
Transport: &http.Transport{MaxIdleConnsPerHost: 10},
}
httpClient := &http.Client{
Transport: &http.Transport{
MaxIdleConnsPerHost: 10,
ForceAttemptHTTP2: true,
TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
},
CheckRedirect: func(req *http.Request, via []*http.Request) error {
if len(via) >= 10 {
return fmt.Errorf("too many redirects")
}
logger.Debug().Str("redirect_url", req.URL.String()).Msg("Following redirect")
return nil
},
Timeout: 600 * time.Second,
}

timeout := 600 * time.Second
ipfsTimeout := 30 * time.Second

// Use model.ID (the hex string CID) as the key for the Cog map
cogConfig, ok := appConfig.ML.Cog[model.ID]
if ok {
// Parse inference timeout only if the string is not empty
if cogConfig.HttpTimeout != "" {
parsedTimeout, err := time.ParseDuration(cogConfig.HttpTimeout)
if err != nil {
logger.Warn().Err(err).Str("model", model.ID).Str("config_timeout", cogConfig.HttpTimeout).Msg("failed to parse model timeout from cog config, using default 120s")
// Keep default timeout
} else {
timeout = parsedTimeout
}
} // Else: HttpTimeout is empty, silently use the default

// Parse IPFS timeout only if the string is not empty
}

if cogConfig.IpfsTimeout != "" {
parsedIpfsTimeout, err := time.ParseDuration(cogConfig.IpfsTimeout)
if err != nil {
logger.Warn().Err(err).Str("model", model.ID).Str("config_ipfs_timeout", cogConfig.IpfsTimeout).Msg("failed to parse IPFS timeout from cog config, using default 30s")
// Keep default ipfsTimeout
} else {
ipfsTimeout = parsedIpfsTimeout
}
} // Else: IpfsTimeout is empty, silently use the default
}
}

// perform validation on the template
templateMeta, ok := ESMFoldV1ModelTemplate.Template.(map[string]any)
if !ok {
logger.Error().Str("model", model.ID).Msg("invalid template format")
Expand Down Expand Up @@ -144,7 +156,7 @@ func NewESMFoldV1Model(client ipfs.IPFSClient, appConfig *config.AppConfig, logg
},
},
ipfs: client,
client: http,
client: httpClient,
logger: logger,
inputFields: inputFields,
}
Expand Down Expand Up @@ -175,72 +187,93 @@ func (m *ESMFoldV1Model) HydrateInput(preprocessedInput map[string]any, seed uin
}
}

var inner ESMFoldV1Input
inner.Sequence, _ = input["sequence"].(string)
var inner ESMFoldV1Prompt
inner.Input.Prompt, _ = input["prompt"].(string)
return inner, nil
}

// GetID returns the ID field of the Model held by this ESMFoldV1Model.
func (m *ESMFoldV1Model) GetID() string {
	id := m.Model.ID
	return id
}

func (m *ESMFoldV1Model) GetFiles(ctx context.Context, gpu *common.GPU, taskid string, input any) ([]ipfs.IPFSFile, error) {
if err := ctx.Err(); err != nil {
m.logger.Warn().Err(err).Str("task", taskid).Msg("Context canceled before GetFiles execution")
return nil, err
}

inner, ok := input.(ESMFoldV1Input)
if !ok {
return nil, fmt.Errorf("invalid input type: expected ESMFoldV1Input")
}

payload := map[string]string{"sequence": inner.Sequence}
marshaledInput, err := json.Marshal(payload)
if err != nil {
return nil, fmt.Errorf("failed to marshal input: %w", err)
}

endpoint := fmt.Sprintf("%s:8080/predict", strings.TrimSuffix(gpu.Url, "/"))
req, err := http.NewRequestWithContext(ctx, "POST", endpoint, bytes.NewBuffer(marshaledInput))
if err != nil {
return nil, fmt.Errorf("failed to create request: %w", err)
}
req.Header.Set("Content-Type", "application/json")

resp, err := m.client.Do(req)
if err != nil {
if errors.Is(err, context.DeadlineExceeded) {
m.logger.Error().Err(err).Str("task", taskid).Str("gpu", endpoint).Msg("model inference request timed out")
return nil, fmt.Errorf("model inference timed out: %w", err)
}
return nil, fmt.Errorf("failed to POST to GPU: %w", err)
}
defer resp.Body.Close()

if resp.StatusCode != http.StatusOK {
bodyBytes, _ := io.ReadAll(resp.Body)
if resp.StatusCode == http.StatusConflict {
m.logger.Warn().Str("task", taskid).Str("gpu", endpoint).Int("status", resp.StatusCode).Str("body", string(bodyBytes)).Msg("resource busy")
return nil, ErrResourceBusy
}
return nil, fmt.Errorf("server returned non-200 status: %d - %s", resp.StatusCode, string(bodyBytes))
}

pdbData, err := io.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("failed to read model response body: %w", err)
}

if len(pdbData) == 0 {
return nil, errors.New("model returned empty PDB data")
}

fileName := fmt.Sprintf("%d.%s.pdb", gpu.ID, uuid.New().String())
path := filepath.Join(m.config.CachePath, fileName)
buffer := bytes.NewBuffer(pdbData)

return []ipfs.IPFSFile{{Name: "result.pdb", Path: path, Buffer: buffer}}, nil
// GetFiles runs inference for a task by POSTing the prompt to the GPU's
// /predict endpoint through the curl binary, and returns the response body as
// a single in-memory IPFS file named "result.pdb".
//
// input must be an ESMFoldV1Prompt. An error is returned when ctx is already
// done or ends while curl is running (the ctx error is wrapped, so errors.Is
// works on it), when curl cannot be executed, when the endpoint answers with
// any HTTP status other than 200, or when the response body is empty.
func (m *ESMFoldV1Model) GetFiles(ctx context.Context, gpu *common.GPU, taskid string, input any) ([]ipfs.IPFSFile, error) {
	if err := ctx.Err(); err != nil {
		m.logger.Warn().Err(err).Str("task", taskid).Msg("Context canceled before GetFiles execution")
		return nil, err
	}

	inner, ok := input.(ESMFoldV1Prompt)
	if !ok {
		return nil, fmt.Errorf("invalid input type: expected ESMFoldV1Prompt")
	}

	// The prompt is sent to the server under the "sequence" key.
	payload := map[string]string{"sequence": inner.Input.Prompt}
	marshaledInput, err := json.Marshal(payload)
	if err != nil {
		return nil, fmt.Errorf("failed to marshal input: %w", err)
	}

	m.logger.Debug().Str("marshaledInput", string(marshaledInput)).Msg("Prepared input payload")
	m.logger.Debug().Str("gpu_url", gpu.Url).Msg("GPU URL")

	endpoint := fmt.Sprintf("%s/predict", strings.TrimSuffix(gpu.Url, "/"))
	m.logger.Debug().Str("endpoint", endpoint).Msg("Constructed endpoint URL")

	tmpFile, err := os.CreateTemp("", "curl_output_*.pdb")
	if err != nil {
		m.logger.Error().Err(err).Str("task", taskid).Msg("Failed to create temporary file")
		return nil, fmt.Errorf("failed to create temporary file: %w", err)
	}
	tmpFileName := tmpFile.Name()
	// One deferred removal covers every return path below.
	defer os.Remove(tmpFileName)
	if err := tmpFile.Close(); err != nil {
		return nil, fmt.Errorf("failed to close temporary file: %w", err)
	}

	curlArgs := []string{
		// Silent, but still report transport errors on stderr, so the error
		// returned on failure stays short instead of embedding a verbose trace.
		"-sS",
		// NOTE(review): -k disables TLS certificate verification. Kept to
		// preserve current behavior; confirm this is really required.
		"-k",
		"-X", "POST",
		"-H", "Content-Type: application/json",
		"-H", "Accept: chemical/x-pdb",
		// The body is read from stdin rather than passed in argv, so long
		// sequences cannot hit argument-length limits or show up in the
		// process list.
		"--data-binary", "@-",
		"-o", tmpFileName,
		// With -o taking the body, stdout carries only the HTTP status code.
		"-w", "%{http_code}",
		endpoint,
	}

	m.logger.Debug().Str("curl_command", fmt.Sprintf("curl %s", strings.Join(curlArgs, " "))).Msg("Executing curl")

	cmd := exec.CommandContext(ctx, "curl", curlArgs...)
	var stdout, stderr bytes.Buffer
	cmd.Stdin = bytes.NewReader(marshaledInput)
	cmd.Stdout = &stdout
	cmd.Stderr = &stderr
	if err := cmd.Run(); err != nil {
		errOutput := strings.TrimSpace(stderr.String())
		m.logger.Error().
			Err(err).
			Str("task", taskid).
			Str("gpu", endpoint).
			Str("stderr", errOutput).
			Msg("Failed to execute curl")
		// When the context ended, curl was killed; surface the ctx error so
		// callers can detect deadline/cancellation with errors.Is.
		if ctxErr := ctx.Err(); ctxErr != nil {
			return nil, fmt.Errorf("model inference canceled: %w", ctxErr)
		}
		return nil, fmt.Errorf("failed to execute curl: %w, stderr: %s", err, errOutput)
	}

	pdbData, err := os.ReadFile(tmpFileName)
	if err != nil {
		m.logger.Error().Err(err).Str("task", taskid).Msg("Failed to read curl output file")
		return nil, fmt.Errorf("failed to read curl output file: %w", err)
	}

	// curl exits 0 for HTTP error responses, so the status must be checked
	// explicitly; otherwise an error body would be returned as PDB data.
	status := strings.TrimSpace(stdout.String())
	if status != "200" {
		const maxBodyInError = 512
		body := pdbData
		if len(body) > maxBodyInError {
			body = body[:maxBodyInError]
		}
		m.logger.Error().
			Str("task", taskid).
			Str("gpu", endpoint).
			Str("status", status).
			Str("body", string(body)).
			Msg("Server returned non-200 status")
		return nil, fmt.Errorf("server returned non-200 status: %s - %s", status, body)
	}

	if len(pdbData) == 0 {
		m.logger.Error().Str("task", taskid).Msg("curl returned empty PDB data")
		return nil, errors.New("curl returned empty PDB data")
	}

	m.logger.Debug().Int("data_length", len(pdbData)).Msg("Received curl response")

	fileName := fmt.Sprintf("%d.%s.pdb", gpu.ID, uuid.New().String())
	path := filepath.Join(m.config.CachePath, fileName)
	buffer := bytes.NewBuffer(pdbData)

	return []ipfs.IPFSFile{{Name: "result.pdb", Path: path, Buffer: buffer}}, nil
}

func (m *ESMFoldV1Model) GetCID(ctx context.Context, gpu *common.GPU, taskid string, input any) ([]byte, error) {
Expand Down