code auth added
This commit is contained in:
parent
f6576a9134
commit
9f655ba8c0
5 changed files with 127 additions and 145 deletions
195
node.go
195
node.go
|
@ -2,8 +2,7 @@ package main
|
|||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"crypto/rand"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
|
@ -11,6 +10,10 @@ import (
|
|||
"net/http"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
libp2p "github.com/libp2p/go-libp2p"
|
||||
crypto "github.com/libp2p/go-libp2p-core/crypto"
|
||||
"github.com/libp2p/go-libp2p-core/peer"
|
||||
)
|
||||
|
||||
var (
|
||||
|
@ -18,6 +21,7 @@ var (
|
|||
peers []string
|
||||
authMutex sync.Mutex
|
||||
authenticated = make(map[string]bool)
|
||||
hostID peer.ID
|
||||
)
|
||||
|
||||
type Message struct {
|
||||
|
@ -35,74 +39,99 @@ type CrawlerConfig struct {
|
|||
|
||||
func loadNodeConfig() {
|
||||
config := loadConfig()
|
||||
authCode = config.AuthCode // nuh uh
|
||||
authCode = config.AuthCode
|
||||
peers = config.Peers
|
||||
}
|
||||
|
||||
func initP2P() (peer.ID, error) {
|
||||
priv, _, err := crypto.GenerateKeyPairWithReader(crypto.Ed25519, 2048, rand.Reader)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to generate key pair: %v", err)
|
||||
}
|
||||
|
||||
h, err := libp2p.New(libp2p.Identity(priv))
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to create libp2p host: %v", err)
|
||||
}
|
||||
|
||||
return h.ID(), nil
|
||||
}
|
||||
|
||||
func sendMessage(serverAddr string, msg Message) error {
|
||||
msgBytes, err := json.Marshal(msg)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal message: %v", err)
|
||||
}
|
||||
|
||||
req, err := http.NewRequest("POST", serverAddr, bytes.NewBuffer(msgBytes))
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create request: %v", err)
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Authorization", authCode)
|
||||
|
||||
client := &http.Client{
|
||||
Timeout: time.Second * 10,
|
||||
}
|
||||
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to send request: %v", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
body, _ := ioutil.ReadAll(resp.Body)
|
||||
return fmt.Errorf("server error: %s", body)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func handleNodeRequest(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodPost {
|
||||
http.Error(w, "Invalid request method", http.StatusMethodNotAllowed)
|
||||
return
|
||||
}
|
||||
|
||||
body, err := ioutil.ReadAll(r.Body)
|
||||
if err != nil {
|
||||
http.Error(w, "Error reading request body", http.StatusInternalServerError)
|
||||
auth := r.Header.Get("Authorization")
|
||||
if auth != authCode {
|
||||
http.Error(w, "Unauthorized", http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
defer r.Body.Close()
|
||||
|
||||
var msg Message
|
||||
if err := json.Unmarshal(body, &msg); err != nil {
|
||||
http.Error(w, "Error parsing JSON", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
if !isAuthenticated(msg.ID) {
|
||||
http.Error(w, "Authentication required", http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
|
||||
interpretMessage(msg)
|
||||
fmt.Fprintln(w, "Message received")
|
||||
}
|
||||
|
||||
func handleAuth(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodPost {
|
||||
http.Error(w, "Invalid request method", http.StatusMethodNotAllowed)
|
||||
return
|
||||
}
|
||||
|
||||
body, err := ioutil.ReadAll(r.Body)
|
||||
err := json.NewDecoder(r.Body).Decode(&msg)
|
||||
if err != nil {
|
||||
http.Error(w, "Error reading request body", http.StatusInternalServerError)
|
||||
http.Error(w, "Error parsing JSON", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
defer r.Body.Close()
|
||||
|
||||
var authRequest CrawlerConfig
|
||||
if err := json.Unmarshal(body, &authRequest); err != nil {
|
||||
http.Error(w, "Error parsing JSON", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
log.Printf("Received message: %+v\n", msg)
|
||||
w.Write([]byte("Message received"))
|
||||
|
||||
expectedCode := GenerateRegistrationCode(authRequest.Host, authRequest.Port, authCode)
|
||||
if authRequest.AuthCode != expectedCode {
|
||||
http.Error(w, "Invalid auth code", http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
|
||||
authMutex.Lock()
|
||||
authenticated[authRequest.ID] = true
|
||||
authMutex.Unlock()
|
||||
|
||||
fmt.Fprintln(w, "Authenticated successfully")
|
||||
interpretMessage(msg)
|
||||
}
|
||||
|
||||
func isAuthenticated(id string) bool {
|
||||
authMutex.Lock()
|
||||
defer authMutex.Unlock()
|
||||
return authenticated[id]
|
||||
func startNodeClient() {
|
||||
for {
|
||||
for _, peerAddr := range peers {
|
||||
msg := Message{
|
||||
ID: hostID.Pretty(),
|
||||
Type: "test",
|
||||
Content: "This is a test message from the client node",
|
||||
}
|
||||
|
||||
err := sendMessage(peerAddr, msg)
|
||||
if err != nil {
|
||||
log.Printf("Error sending message to %s: %v", peerAddr, err)
|
||||
} else {
|
||||
log.Println("Message sent successfully to", peerAddr)
|
||||
}
|
||||
}
|
||||
time.Sleep(10 * time.Second)
|
||||
}
|
||||
}
|
||||
|
||||
func interpretMessage(msg Message) {
|
||||
|
@ -120,71 +149,3 @@ func interpretMessage(msg Message) {
|
|||
fmt.Println("Received unknown message type:", msg.Type)
|
||||
}
|
||||
}
|
||||
|
||||
func sendMessage(address, id, msgType, content string) error {
|
||||
msg := Message{
|
||||
ID: id,
|
||||
Type: msgType,
|
||||
Content: content,
|
||||
}
|
||||
msgBytes, err := json.Marshal(msg)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
req, err := http.NewRequest(http.MethodPost, fmt.Sprintf("%s/node", address), bytes.NewBuffer(msgBytes))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
client := &http.Client{}
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
body, _ := ioutil.ReadAll(resp.Body)
|
||||
return fmt.Errorf("failed to send message: %s", body)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func startNodeClient(addresses []string) {
|
||||
for _, address := range addresses {
|
||||
go func(addr string) {
|
||||
for {
|
||||
err := sendMessage(addr, authCode, "test", "This is a test message")
|
||||
if err != nil {
|
||||
fmt.Println("Error sending test message to", addr, ":", err)
|
||||
continue
|
||||
}
|
||||
time.Sleep(10 * time.Second)
|
||||
}
|
||||
}(address)
|
||||
}
|
||||
}
|
||||
|
||||
func GenerateRegistrationCode(host string, port int, authCode string) string {
|
||||
data := fmt.Sprintf("%s:%d:%s", host, port, authCode)
|
||||
hash := sha256.Sum256([]byte(data))
|
||||
return hex.EncodeToString(hash[:])
|
||||
}
|
||||
|
||||
func ParseRegistrationCode(code string, host string, port int, authCode string) (string, int, string, error) {
|
||||
data := fmt.Sprintf("%s:%d:%s", host, port, authCode)
|
||||
hash := sha256.Sum256([]byte(data))
|
||||
expectedCode := hex.EncodeToString(hash[:])
|
||||
|
||||
log.Printf("Parsing registration code: %s", code)
|
||||
log.Printf("Expected registration code: %s", expectedCode)
|
||||
|
||||
if expectedCode != code {
|
||||
return "", 0, "", fmt.Errorf("invalid registration code")
|
||||
}
|
||||
|
||||
return host, port, authCode, nil
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue