adding getsourcefunction sha256 check

This commit is contained in:
2025-10-31 21:32:35 -03:00
parent cbea1dd8b5
commit ff5e271195
3 changed files with 129 additions and 20 deletions

View File

@@ -1,24 +1,39 @@
package packet package packet
type Config struct { type Config struct {
BinDir string BinDir *string
} }
const defaultBinDir = "/usr/bin" const defaultBinDir = "/usr/bin"
func checkConfig(cfg *Config) *Config { func checkConfig(cfg *Config) *Config {
if cfg == nil { if cfg == nil {
bin := defaultBinDir
return &Config{ return &Config{
BinDir: defaultBinDir, BinDir: &bin,
} }
} }
if cfg.BinDir == "" { if *cfg.BinDir == "" || cfg.BinDir == nil {
bin := defaultBinDir
return &Config{ return &Config{
BinDir: defaultBinDir, BinDir: &bin,
} }
} else { } else {
return cfg return cfg
} }
} }
func checkConfigSrc(cfg *GetSourceConfig) *GetSourceConfig {
if cfg == nil {
return nil
}
switch {
case *cfg.PacketDir == "" || cfg.PacketDir == nil:
s := randStringBytes(12)
cfg.PacketDir = &s
}
return cfg
}

View File

@@ -2,14 +2,20 @@ package packet
import ( import (
"archive/tar" "archive/tar"
"bytes"
"crypto/sha256"
"encoding/hex"
"errors" "errors"
"fmt" "fmt"
"io" "io"
"net/http"
"path/filepath" "path/filepath"
"runtime" "runtime"
"time"
"github.com/klauspost/compress/zstd" "github.com/klauspost/compress/zstd"
lua_utils "github.com/roboogg133/packets/internal/lua"
lua "github.com/yuin/gopher-lua" lua "github.com/yuin/gopher-lua"
) )
@@ -28,12 +34,13 @@ type PacketLua struct {
Build *lua.LFunction Build *lua.LFunction
Install *lua.LFunction Install *lua.LFunction
PreRemove *lua.LFunction
} }
type Source struct { type Source struct {
Method string Method string
Url string Url string
Specs interface{} Specs any
} }
type VersionConstraint string type VersionConstraint string
@@ -70,6 +77,8 @@ type GETSpecs struct {
var ErrCantFindPacketDotLua = errors.New("can't find Packet.lua in .tar.zst file") var ErrCantFindPacketDotLua = errors.New("can't find Packet.lua in .tar.zst file")
var ErrFileDontReturnTable = errors.New("invalid Packet.lua format: the file do not return a table") var ErrFileDontReturnTable = errors.New("invalid Packet.lua format: the file do not return a table")
var ErrCannotFindPackageTable = errors.New("invalid Packet.lua format: can't find package table") var ErrCannotFindPackageTable = errors.New("invalid Packet.lua format: can't find package table")
var ErrInstallFunctionDoesNotExist = errors.New("can not find instal()")
var ErrSha256Sum = errors.New("false checksum")
// ReadPacket read a Packet.lua and alredy set global vars // ReadPacket read a Packet.lua and alredy set global vars
func ReadPacket(f []byte, cfg *Config) (PacketLua, error) { func ReadPacket(f []byte, cfg *Config) (PacketLua, error) {
@@ -79,15 +88,19 @@ func ReadPacket(f []byte, cfg *Config) (PacketLua, error) {
defer L.Close() defer L.Close()
osObject := L.GetGlobal("os").(*lua.LTable) osObject := L.GetGlobal("os").(*lua.LTable)
osObject.RawSetString("setenv", L.NewFunction(lua_utils.LSetEnv))
ioObject := L.GetGlobal("io").(*lua.LTable) ioObject := L.GetGlobal("io").(*lua.LTable)
L.SetGlobal("os", lua.LNil) L.SetGlobal("os", lua.LNil)
L.SetGlobal("io", lua.LNil) L.SetGlobal("io", lua.LNil)
L.SetGlobal("BIN_DIR", lua.LString(cfg.BinDir)) L.SetGlobal("BIN_DIR", lua.LString(*cfg.BinDir))
L.SetGlobal("CURRENT_ARCH", lua.LString(runtime.GOARCH)) L.SetGlobal("CURRENT_ARCH", lua.LString(runtime.GOARCH))
L.SetGlobal("CURRENT_ARCH_NORMALIZED", lua.LString(normalizeArch(runtime.GOARCH)))
L.SetGlobal("CURRENT_PLATAFORM", lua.LString(runtime.GOOS)) L.SetGlobal("CURRENT_PLATAFORM", lua.LString(runtime.GOOS))
L.SetGlobal("pathjoin", L.NewFunction(lua_utils.Ljoin))
if err := L.DoString(string(f)); err != nil { if err := L.DoString(string(f)); err != nil {
return PacketLua{}, err return PacketLua{}, err
} }
@@ -123,10 +136,11 @@ func ReadPacket(f []byte, cfg *Config) (PacketLua, error) {
Build: getFunctionFromTable(table, "build"), Build: getFunctionFromTable(table, "build"),
Install: getFunctionFromTable(table, "install"), Install: getFunctionFromTable(table, "install"),
PreRemove: getFunctionFromTable(table, "pre_remove"),
} }
if packetLua.Install == nil { if packetLua.Install == nil {
return PacketLua{}, fmt.Errorf("install() does not exist") return PacketLua{}, ErrInstallFunctionDoesNotExist
} }
return *packetLua, nil return *packetLua, nil
@@ -164,3 +178,90 @@ func ReadPacketFromZSTDF(file io.Reader, cfg *Config) (PacketLua, error) {
} }
return PacketLua{}, ErrCantFindPacketDotLua return PacketLua{}, ErrCantFindPacketDotLua
} }
type GetSourceConfig struct {
PacketDir *string
}
func GetSource(url, method string, info any) ([]byte, error) {
switch method {
case "GET":
req, err := http.NewRequest("GET", url, nil)
if err != nil {
return nil, err
}
specs := info.(GETSpecs)
for k, v := range *specs.Headers {
req.Header.Set(k, v)
}
client := http.Client{Timeout: 5 * time.Minute}
resp, err := client.Do(req)
if err != nil {
return nil, err
}
defer resp.Body.Close()
data, err := io.ReadAll(resp.Body)
if err != nil {
return nil, err
}
if !verifySHA256(*specs.SHA256, data) {
return nil, ErrSha256Sum
}
return data, nil
case "POST":
specs := info.(POSTSpecs)
var body *bytes.Reader
if specs.Body != nil {
body = bytes.NewReader([]byte(*specs.Body))
} else {
body = nil
}
req, err := http.NewRequest("POST", url, body)
if err != nil {
return nil, err
}
for k, v := range *specs.Headers {
req.Header.Set(k, v)
}
client := http.Client{Timeout: 5 * time.Minute}
resp, err := client.Do(req)
if err != nil {
return nil, err
}
defer resp.Body.Close()
data, err := io.ReadAll(resp.Body)
if err != nil {
return nil, err
}
if !verifySHA256(*specs.SHA256, data) {
return nil, ErrSha256Sum
}
return data, nil
}
return nil, fmt.Errorf("invalid method")
}
func verifySHA256(checksum string, src []byte) bool {
check := sha256.Sum256(src)
return hex.EncodeToString(check[:]) == checksum
}

View File

@@ -1,9 +1,5 @@
package packet package packet
import (
"fmt"
)
func (pkg PacketLua) IsValid() bool { func (pkg PacketLua) IsValid() bool {
var a, b int var a, b int
@@ -15,8 +11,7 @@ func (pkg PacketLua) IsValid() bool {
a += len(*pkg.GlobalSources) a += len(*pkg.GlobalSources)
if a <= 0 || b <= 0 { if a < 1 || len(*pkg.Plataforms) > b {
fmt.Println("invalid")
return false return false
} }
@@ -26,7 +21,5 @@ func (pkg PacketLua) IsValid() bool {
case pkg.Description == "" || pkg.Maintaner == "" || pkg.Name == "" || pkg.Version == "": case pkg.Description == "" || pkg.Maintaner == "" || pkg.Name == "" || pkg.Version == "":
return false return false
} }
fmt.Println("valid")
return true return true
} }