From b13f0ab3c814efbed5ec704466ac56dce6aec4af Mon Sep 17 00:00:00 2001
From: Cody
Date: Tue, 24 Jun 2025 20:13:45 -0400
Subject: [PATCH] Initial work to support ESMFold model

---
 README.md                  |   2 +-
 config/config.testnet.json |  12 ++
 models/esmfold.go          | 285 +++++++++++++++++++++++++++++++++++++
 models/models.go           |  60 +++++----
 4 files changed, 331 insertions(+), 28 deletions(-)
 create mode 100644 models/esmfold.go

diff --git a/README.md b/README.md
index f77dfc3..5d6c350 100644
--- a/README.md
+++ b/README.md
@@ -186,7 +186,7 @@ Basic usage:
 | `--config` | Path to the configuration file | "config.json" | No |
 | `--skipvalidation` | Skip safety checks and validation of the model and miner version | false | No |
 | `--loglevel` | Set logging verbosity (1 = default) | 1 | No |
-| `--testnet` | Run using testnet (1 = local, 2 = nova testnet) | 0 | No |
+| `--testnet` | Run using testnet (1 = local, 2 = Arbitrum Sepolia testnet) | 0 | No |
 | `--taskscanner` | Scan blocks for unsolved tasks in pst 12 hours | 0 | No |
 
 #### Example Commands
diff --git a/config/config.testnet.json b/config/config.testnet.json
index 0e60866..cac27c8 100644
--- a/config/config.testnet.json
+++ b/config/config.testnet.json
@@ -27,6 +27,18 @@
                 "rate": "0",
                 "cid": "0x12208dc17f4317285392e6811ca31c9efeb3b8511ebac1bd14b935ecf2c1b60c1b09"
             }
+        },
+        "esmfold-v1": {
+            "id": "0xbf632554a7a1c5162ef6150b17a79f30b90850f7e0c56391da0ec810306d021b",
+            "mineable": true,
+            "contracts": {},
+            "params": {
+                "addr": "0xD7d5bCaC95d26aB06333238E57219B15270DFc8d",
+                "fee": "0",
+                "rate": "0",
+                "cid": "0x12200e4c49b7ee8a0955a773e657444880712f639341237ad63a5f59303a343f4e4e"
+            }
         }
+
     }
 }
diff --git a/models/esmfold.go b/models/esmfold.go
new file mode 100644
index 0000000..d50232b
--- /dev/null
+++ b/models/esmfold.go
@@ -0,0 +1,285 @@
+package models
+
+import (
+	"bytes"
+	"context"
+	"encoding/json"
+	"errors"
+	"fmt"
+	"gobius/common"
+	"gobius/config"
+	"gobius/ipfs"
+	"gobius/utils"
+	"io"
+	"net/http"
+	"path/filepath"
+	"strings"
+	"time"
+
+	"github.com/google/uuid"
+	"github.com/mr-tron/base58"
+	"github.com/rs/zerolog"
+)
+
+// ESMFoldV1Input is the hydrated task input for the ESMFold v1 model.
+type ESMFoldV1Input struct {
+	// Sequence is the amino acid sequence posted to the inference server.
+	Sequence string `json:"sequence"`
+}
+
+// ESMFoldV1Model requests ESMFold v1 predictions from a GPU endpoint and
+// pins the resulting PDB file to IPFS.
+type ESMFoldV1Model struct {
+	Model
+	timeoutDuration     time.Duration
+	ipfsTimeoutDuration time.Duration
+	Filters             []MiningFilter
+	config              *config.AppConfig
+	client              *http.Client
+	logger              zerolog.Logger
+	ipfs                ipfs.IPFSClient
+}
+
+var _ ModelInterface = (*ESMFoldV1Model)(nil)
+
+// ESMFoldV1ModelTemplate holds the model metadata plus its input and output
+// schema. ID is empty here; NewESMFoldV1Model fills it in from the app config.
+var ESMFoldV1ModelTemplate = Model{
+	ID:       "",
+	Mineable: true,
+	Template: map[string]any{
+		"meta": map[string]any{
+			"title":       "ESMFold v1",
+			"description": "Protein folding model from Facebook Research based on ESM architecture",
+			"git":         "https://github.com/facebookresearch/esm",
+			"docker":      "https://hub.docker.com/r/cody8295/esmfold-arbius",
+			"version":     1,
+		},
+		"input": []map[string]any{
+			{
+				"variable":    "sequence",
+				"type":        "string",
+				"required":    true,
+				"default":     "",
+				"description": "Amino acid sequence to predict protein structure from",
+			},
+		},
+		"output": []map[string]any{
+			{
+				"filename": "result.pdb",
+				"type":     "text",
+			},
+		},
+	},
+}
+
+// NewESMFoldV1Model builds the model from the "esmfold-v1" entry of the app
+// config. It returns nil when that entry is missing or has an empty ID.
+func NewESMFoldV1Model(client ipfs.IPFSClient, appConfig *config.AppConfig, logger zerolog.Logger) *ESMFoldV1Model {
+	model, ok := appConfig.BaseConfig.Models["esmfold-v1"]
+	if !ok {
+		return nil
+	}
+
+	if model.ID == "" {
+		logger.Error().Str("model", "esmfold-v1").Msg("model ID is empty")
+		return nil
+	}
+
+	// Named httpClient so the net/http package is not shadowed.
+	httpClient := &http.Client{
+		Transport: &http.Transport{MaxIdleConnsPerHost: 10},
+	}
+
+	timeout := 600 * time.Second
+	ipfsTimeout := 30 * time.Second
+
+	m := &ESMFoldV1Model{
+		Model:               ESMFoldV1ModelTemplate,
+		timeoutDuration:     timeout,
+		ipfsTimeoutDuration: ipfsTimeout,
+		config:              appConfig,
+		Filters: []MiningFilter{
+			{
+				MinFee:  0,
+				MinTime: 0,
+			},
+		},
+		ipfs:   client,
+		client: httpClient,
+		logger: logger,
+	}
+	m.Model.ID = model.ID
+	return m
+}
+
+// HydrateInput validates preprocessedInput against the template's "input"
+// schema and returns it as an ESMFoldV1Input. The seed is not used.
+func (m *ESMFoldV1Model) HydrateInput(preprocessedInput map[string]any, seed uint64) (InputHydrationResult, error) {
+	input := make(map[string]any)
+
+	tmpl, ok := m.Model.Template.(map[string]any)
+	if !ok {
+		return nil, fmt.Errorf("invalid template format")
+	}
+
+	// "input" is a top-level key of the template, a sibling of "meta".
+	inputFields, ok := tmpl["input"].([]map[string]any)
+	if !ok {
+		return nil, fmt.Errorf("invalid input format in template")
+	}
+
+	for _, field := range inputFields {
+		varName, ok := field["variable"].(string)
+		if !ok {
+			return nil, fmt.Errorf("invalid variable name in template input")
+		}
+		fieldType, ok := field["type"].(string)
+		if !ok {
+			return nil, fmt.Errorf("invalid type for template input (%s)", varName)
+		}
+		// A missing or non-bool "required" entry means the field is optional.
+		required, _ := field["required"].(bool)
+
+		value, exists := preprocessedInput[varName]
+		if required && !exists {
+			return nil, fmt.Errorf("input missing required field (%s)", varName)
+		}
+
+		if exists {
+			if err := validateType(value, fieldType, varName); err != nil {
+				return nil, err
+			}
+			input[varName] = value
+		} else {
+			input[varName] = field["default"]
+		}
+	}
+
+	var inner ESMFoldV1Input
+	// A non-string value leaves Sequence empty rather than panicking.
+	inner.Sequence, _ = input["sequence"].(string)
+	return inner, nil
+}
+
+// GetID returns the model ID taken from the app config.
+func (m *ESMFoldV1Model) GetID() string {
+	return m.Model.ID
+}
+
+// GetFiles posts the sequence to the GPU's /predict endpoint and wraps the
+// returned PDB data as a single "result.pdb" IPFS file. A 409 Conflict
+// response is reported as ErrResourceBusy.
+func (m *ESMFoldV1Model) GetFiles(ctx context.Context, gpu *common.GPU, taskid string, input any) ([]ipfs.IPFSFile, error) {
+	if err := ctx.Err(); err != nil {
+		m.logger.Warn().Err(err).Str("task", taskid).Msg("Context canceled before GetFiles execution")
+		return nil, err
+	}
+
+	inner, ok := input.(ESMFoldV1Input)
+	if !ok {
+		return nil, fmt.Errorf("invalid input type: expected ESMFoldV1Input")
+	}
+
+	payload := map[string]string{"sequence": inner.Sequence}
+	marshaledInput, err := json.Marshal(payload)
+	if err != nil {
+		return nil, fmt.Errorf("failed to marshal input: %w", err)
+	}
+
+	endpoint := fmt.Sprintf("%s:8080/predict", strings.TrimSuffix(gpu.Url, "/"))
+	req, err := http.NewRequestWithContext(ctx, "POST", endpoint, bytes.NewBuffer(marshaledInput))
+	if err != nil {
+		return nil, fmt.Errorf("failed to create request: %w", err)
+	}
+	req.Header.Set("Content-Type", "application/json")
+
+	resp, err := m.client.Do(req)
+	if err != nil {
+		if errors.Is(err, context.DeadlineExceeded) {
+			m.logger.Error().Err(err).Str("task", taskid).Str("gpu", endpoint).Msg("model inference request timed out")
+			return nil, fmt.Errorf("model inference timed out: %w", err)
+		}
+		return nil, fmt.Errorf("failed to POST to GPU: %w", err)
+	}
+	defer resp.Body.Close()
+
+	if resp.StatusCode != http.StatusOK {
+		bodyBytes, _ := io.ReadAll(resp.Body)
+		if resp.StatusCode == http.StatusConflict {
+			m.logger.Warn().Str("task", taskid).Str("gpu", endpoint).Int("status", resp.StatusCode).Str("body", string(bodyBytes)).Msg("resource busy")
+			return nil, ErrResourceBusy
+		}
+		return nil, fmt.Errorf("server returned non-200 status: %d - %s", resp.StatusCode, string(bodyBytes))
+	}
+
+	pdbData, err := io.ReadAll(resp.Body)
+	if err != nil {
+		return nil, fmt.Errorf("failed to read model response body: %w", err)
+	}
+
+	if len(pdbData) == 0 {
+		return nil, errors.New("model returned empty PDB data")
+	}
+
+	fileName := fmt.Sprintf("%d.%s.pdb", gpu.ID, uuid.New().String())
+	path := filepath.Join(m.config.CachePath, fileName)
+	buffer := bytes.NewBuffer(pdbData)
+
+	return []ipfs.IPFSFile{{Name: "result.pdb", Path: path, Buffer: buffer}}, nil
+}
+
+// GetCID fetches the model output via GetFiles, pins it to IPFS and returns
+// the base58-decoded CID. Each stage runs under its own timeout.
+func (m *ESMFoldV1Model) GetCID(ctx context.Context, gpu *common.GPU, taskid string, input any) ([]byte, error) {
+	if err := ctx.Err(); err != nil {
+		return nil, fmt.Errorf("parent context canceled before GetCID: %w", err)
+	}
+
+	timeoutCtx, cancel := context.WithTimeout(ctx, m.timeoutDuration)
+	defer cancel()
+
+	paths, err := utils.ExpRetryWithContext(timeoutCtx, m.logger, func() (any, error) {
+		return m.GetFiles(timeoutCtx, gpu, taskid, input)
+	}, 3, 1000)
+	if err != nil {
+		if errors.Is(err, ErrResourceBusy) {
+			m.logger.Warn().Str("task", taskid).Str("gpu", gpu.Url).Msg("GPU remained busy after retries")
+		}
+		return nil, err
+	}
+
+	ipfsCtx, ipfsCancel := context.WithTimeout(ctx, m.ipfsTimeoutDuration)
+	defer ipfsCancel()
+
+	cid58, err := utils.ExpRetryWithContext(ipfsCtx, m.logger, func() (any, error) {
+		return m.ipfs.PinFilesToIPFS(ipfsCtx, taskid, paths.([]ipfs.IPFSFile))
+	}, 3, 1000)
+	if err != nil {
+		return nil, fmt.Errorf("failed to pin files to IPFS after retries: %w", err)
+	}
+	cidBytes, err := base58.Decode(cid58.(string))
+	if err != nil {
+		return nil, fmt.Errorf("failed to decode base58 CID string: %w", err)
+	}
+
+	return cidBytes, nil
+}
+
+// Validate currently performs no checks and always returns nil; the
+// commented-out self-test is kept as a starting point for the TODO below.
+func (m *ESMFoldV1Model) Validate(gpu *common.GPU, taskid string) error {
+	// testInput := ESMFoldV1Input{
+	// 	Sequence: "MKTAYIAKQRQISFVKSHFSRQLEERLGLIEVQAPILSRVGDGTQDNLSGAEKAVQVKVKALPDAQFKKVVK",
+	// }
+
+	// cid, err := m.GetCID(context.Background(), gpu, "startup-test-taskid", testInput)
+	// if err != nil {
+	// 	return fmt.Errorf("validation failed: %w", err)
+	// }
+
+	//cidStr := "0x"
+	// TODO: Implement Validation here
+
+	return nil
+}
diff --git a/models/models.go b/models/models.go
index ef26753..2f464f2 100644
--- a/models/models.go
+++ b/models/models.go
@@ -78,33 +78,39 @@ var ModelRegistry *ModelFactory
 func InitModelRegistry(client ipfs.IPFSClient, config *config.AppConfig, logger zerolog.Logger) {
 	ModelRegistry = NewModelFactory()
 
-	// Register Qwen Mainnet
-	modelQwenMainnet := NewQwenMainnetModel(client, config, logger)
-	if modelQwenMainnet != nil {
-		ModelRegistry.RegisterModel(modelQwenMainnet)
-	}
-
-	modelWaiV120Mainnet := NewWaiV120MainnetModel(client, config, logger)
-	if modelWaiV120Mainnet != nil {
-		ModelRegistry.RegisterModel(modelWaiV120Mainnet)
-	}
-
-	// Sepolia testnet model
-	modelQwenTest := NewQwenTestModel(client, config, logger)
-	if modelQwenTest != nil {
-		ModelRegistry.RegisterModel(modelQwenTest)
-	}
-
-	// Register Kandinsky2
-	modelKandinsky2 := NewKandinsky2Model(client, config, logger)
-	if modelKandinsky2 != nil {
-		ModelRegistry.RegisterModel(modelKandinsky2)
-	}
-
-	// Register Metabaron-Uncensored-8B
-	modelMetabaron := NewMetabaronModel(client, config, logger)
-	if modelMetabaron != nil {
-		ModelRegistry.RegisterModel(modelMetabaron)
+	// // Register Qwen Mainnet
+	// modelQwenMainnet := NewQwenMainnetModel(client, config, logger)
+	// if modelQwenMainnet != nil {
+	// 	ModelRegistry.RegisterModel(modelQwenMainnet)
+	// }
+
+	// modelWaiV120Mainnet := NewWaiV120MainnetModel(client, config, logger)
+	// if modelWaiV120Mainnet != nil {
+	// 	ModelRegistry.RegisterModel(modelWaiV120Mainnet)
+	// }
+
+	// // Sepolia testnet model
+	// modelQwenTest := NewQwenTestModel(client, config, logger)
+	// if modelQwenTest != nil {
+	// 	ModelRegistry.RegisterModel(modelQwenTest)
+	// }
+
+	// // Register Kandinsky2
+	// modelKandinsky2 := NewKandinsky2Model(client, config, logger)
+	// if modelKandinsky2 != nil {
+	// 	ModelRegistry.RegisterModel(modelKandinsky2)
+	// }
+
+	// // Register Metabaron-Uncensored-8B
+	// modelMetabaron := NewMetabaronModel(client, config, logger)
+	// if modelMetabaron != nil {
+	// 	ModelRegistry.RegisterModel(modelMetabaron)
+	// }
+
+	// Register ESMFold V1
+	modelESMFoldV1 := NewESMFoldV1Model(client, config, logger)
+	if modelESMFoldV1 != nil {
+		ModelRegistry.RegisterModel(modelESMFoldV1)
 	}
 }
 