A More Secure Internet Connection for Your Home https://fen.gg
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 
 
 
 
 
 

328 lines
8.2 KiB

  1. package model
  2. //
  3. // Fengg Security Gateway Server Application
  4. // Copyright (C) 2020 Lukas Matt <support@fen.gg>
  5. //
  6. // This program is free software: you can redistribute it and/or modify
  7. // it under the terms of the GNU General Public License as published by
  8. // the Free Software Foundation, either version 3 of the License.
  9. //
  10. // This program is distributed in the hope that it will be useful,
  11. // but WITHOUT ANY WARRANTY; without even the implied warranty of
  12. // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
  13. // GNU General Public License for more details.
  14. //
  15. // You should have received a copy of the GNU General Public License
  16. // along with this program. If not, see <https://www.gnu.org/licenses/>.
  17. //
  18. import (
  19. "sync"
  20. "errors"
  21. "time"
  22. "fmt"
  23. "github.com/jmoiron/sqlx"
  24. _ "github.com/lib/pq"
  25. )
  26. const (
  27. ConnectionResetCounterTmpl = `UPDATE connections SET count=0;`
  28. ConnectionNamedInsertTmpl = `
  29. INSERT INTO connections (
  30. updated_at, transport_protocol_id, application_protocol_id,
  31. destination_id, source_id, count
  32. ) VALUES (
  33. :updated_at, :transport_protocol_id, :application_protocol_id,
  34. :destination_id, :source_id, :count
  35. ) RETURNING id`
  36. ConnectionNamedUpdateTmpl = `
  37. UPDATE connections
  38. SET
  39. updated_at=now(), %s
  40. WHERE destination_id=:destination_id
  41. AND source_id=:source_id
  42. AND transport_protocol_id=:transport_protocol_id
  43. AND application_protocol_id=:application_protocol_id;`
  44. ConnectionCountryQueryTmpl = `
  45. SELECT CONCAT(i1.country, i2.country) as mcountry
  46. FROM connections AS c
  47. INNER JOIN ip_information AS i1 ON c.source_id=i1.id
  48. INNER JOIN ip_information AS i2 ON c.destination_id=i2.id
  49. WHERE (
  50. i1.country!='' or i2.country!= ''
  51. ) %s GROUP BY mcountry;`
  52. ConnectionQueryTmpl = `SELECT * FROM connections WHERE %s;`
  53. )
  54. type Connection struct {
  55. ID uint `db:"id" json:"id"`
  56. UpdatedAt time.Time `db:"updated_at" json:"updatedAt"`
  57. Count uint `db:"count" json:"count"`
  58. SourceID uint `db:"source_id" json:"sourceId"`
  59. DestinationID uint `db:"destination_id" json:"destinationId"`
  60. TransportProtocolID uint `db:"transport_protocol_id" json:"transportProtocolId"`
  61. ApplicationProtocolID uint `db:"application_protocol_id" json:"applicationProtocolId"`
  62. Source IPInformation `db:"-" json:"source"`
  63. Destination IPInformation `db:"-" json:"destination"`
  64. TransportProtocol Protocol `db:"-" json:"transportProtocol"`
  65. ApplicationProtocol Protocol `db:"-" json:"applicationProtocol"`
  66. }
  67. type Connections []Connection
  68. func NewConnection() *Connection {
  69. return &Connection{
  70. UpdatedAt: time.Now(),
  71. Count: 1,
  72. }
  73. }
  74. func (connections Connections) PreLoad() error {
  75. for _, connection := range connections {
  76. err := connection.PreLoad()
  77. if err != nil {
  78. return err
  79. }
  80. }
  81. return nil
  82. }
  83. // PreLoad will fetch all extra information like e.g.
  84. // IP information and write them to the connection struct
  85. func (a *Connection) PreLoad() error {
  86. a.Source.ID = a.SourceID
  87. err := a.Source.Find()
  88. if err != nil {
  89. return err
  90. }
  91. // if all IDs are zero the Find function can
  92. // try looking for a record via the IP address
  93. // or in case of a protocol the name attribute.
  94. // Then we have the actually ID of the record
  95. // within the nested struct which we should populate
  96. // to the actual connection struct.
  97. a.SourceID = a.Source.ID
  98. a.Destination.ID = a.DestinationID
  99. err = a.Destination.Find()
  100. if err != nil {
  101. return err
  102. }
  103. a.DestinationID = a.Destination.ID
  104. a.TransportProtocol.ID = a.TransportProtocolID
  105. err = a.TransportProtocol.Find()
  106. if err != nil {
  107. return err
  108. }
  109. a.TransportProtocolID = a.TransportProtocol.ID
  110. a.ApplicationProtocol.ID = a.ApplicationProtocolID
  111. err = a.ApplicationProtocol.Find()
  112. if err != nil {
  113. return err
  114. }
  115. a.ApplicationProtocolID = a.ApplicationProtocol.ID
  116. return nil
  117. }
  118. func (a *Connection) Valid() bool {
  119. return a.Source.IP.Raw != nil && a.Destination.IP.Raw != nil
  120. }
  121. func (a *Connection) Exists() bool {
  122. db, err := sqlx.Connect(dbDriver, dbConnect)
  123. if err != nil {
  124. logger.Error().Err(err).Msg("cannot connect to database")
  125. return false
  126. }
  127. defer db.Close()
  128. err = a.PreLoad()
  129. if err != nil {
  130. logger.Warn().Err(err).Msg("cannot preload connection")
  131. }
  132. if a.ID > 0 {
  133. err = db.Get(a, fmt.Sprintf(ConnectionQueryTmpl, `id=$1`), a.ID)
  134. } else {
  135. err = db.Get(a, fmt.Sprintf(ConnectionQueryTmpl, `
  136. destination_id=$1
  137. AND source_id=$2
  138. AND application_protocol_id=$3
  139. AND transport_protocol_id=$4`,
  140. ), a.DestinationID, a.SourceID, a.ApplicationProtocolID, a.TransportProtocolID)
  141. }
  142. if err != nil {
  143. logger.Debug().Err(err).Msg("cannot find connection entry")
  144. }
  145. return err == nil
  146. }
  147. func (a *Connection) Create() error {
  148. if !a.Valid() {
  149. return errors.New("connection is invalid")
  150. }
  151. a.LookupExtraInformation()
  152. err := a.Source.CreateOrUpdate()
  153. if err != nil {
  154. return err
  155. }
  156. a.SourceID = a.Source.ID
  157. err = a.Destination.CreateOrUpdate()
  158. if err != nil {
  159. return err
  160. }
  161. a.DestinationID = a.Destination.ID
  162. err = a.TransportProtocol.CreateIfNotExists()
  163. if err != nil {
  164. return err
  165. }
  166. a.TransportProtocolID = a.TransportProtocol.ID
  167. err = a.ApplicationProtocol.CreateIfNotExists()
  168. if err != nil {
  169. return err
  170. }
  171. a.ApplicationProtocolID = a.ApplicationProtocol.ID
  172. db, err := sqlx.Connect(dbDriver, dbConnect)
  173. if err != nil {
  174. return err
  175. }
  176. defer db.Close()
  177. _, err = db.NamedExec(`INSERT INTO connections (
  178. updated_at, transport_protocol_id, application_protocol_id,
  179. destination_id, source_id, count
  180. ) VALUES (
  181. :updated_at, :transport_protocol_id, :application_protocol_id,
  182. :destination_id, :source_id, :count
  183. ) RETURNING id`, a)
  184. return err
  185. }
  186. func (a *Connection) Update() error {
  187. if !a.Valid() {
  188. return errors.New("connection is invalid")
  189. }
  190. twoHourCache := a.UpdatedAt.Add(2 * time.Hour)
  191. if twoHourCache.Sub(time.Now()) < 0 {
  192. // lookup names every two hours
  193. a.LookupExtraInformation()
  194. }
  195. err := a.Source.CreateOrUpdate()
  196. if err != nil {
  197. return err
  198. }
  199. a.SourceID = a.Source.ID
  200. err = a.Destination.CreateOrUpdate()
  201. if err != nil {
  202. return err
  203. }
  204. a.DestinationID = a.Destination.ID
  205. err = a.TransportProtocol.CreateIfNotExists()
  206. if err != nil {
  207. return err
  208. }
  209. a.TransportProtocolID = a.TransportProtocol.ID
  210. err = a.ApplicationProtocol.CreateIfNotExists()
  211. if err != nil {
  212. return err
  213. }
  214. a.ApplicationProtocolID = a.ApplicationProtocol.ID
  215. db, err := sqlx.Connect(dbDriver, dbConnect)
  216. if err != nil {
  217. return err
  218. }
  219. defer db.Close()
  220. _, err = db.NamedExec(fmt.Sprintf(ConnectionNamedUpdateTmpl, "count=:count"), a)
  221. return err
  222. }
  223. func (a *Connection) LookupExtraInformation() {
  224. var waitGroup sync.WaitGroup
  225. waitGroup.Add(6)
  226. // lookup GeoIP and rDNS for the source IP
  227. go func() {
  228. defer waitGroup.Done()
  229. a.Source.LookupRDNSAndGeoIP()
  230. }()
  231. // lookup MAC address for the source IP
  232. go func() {
  233. defer waitGroup.Done()
  234. err := a.Source.LookupMAC()
  235. if err != nil {
  236. logger.Debug().Err(err).Msg("cannot find MAC address")
  237. }
  238. }()
  239. // lookup NetBIOS for the source IP
  240. go func() {
  241. defer waitGroup.Done()
  242. err := a.Source.LookupNetBios()
  243. if err != nil {
  244. logger.Debug().Err(err).Msg("cannot find NetBIOS name")
  245. }
  246. }()
  247. // lookup GeoIP and rDNS for the destination IP
  248. go func() {
  249. defer waitGroup.Done()
  250. a.Destination.LookupRDNSAndGeoIP()
  251. }()
  252. // lookup MAC address for the destination IP
  253. go func() {
  254. defer waitGroup.Done()
  255. err := a.Destination.LookupMAC()
  256. if err != nil {
  257. logger.Debug().Err(err).Msg("cannot find MAC address")
  258. }
  259. }()
  260. // lookup NetBIOS for the destination IP
  261. go func() {
  262. defer waitGroup.Done()
  263. err := a.Destination.LookupNetBios()
  264. if err != nil {
  265. logger.Debug().Err(err).Msg("cannot find MAC address")
  266. }
  267. }()
  268. waitGroup.Wait()
  269. }
  270. func (a *Connections) FindRecentConnections() error {
  271. db, err := sqlx.Connect(dbDriver, dbConnect)
  272. if err != nil {
  273. return err
  274. }
  275. defer db.Close()
  276. err = db.Select(a, fmt.Sprintf(ConnectionQueryTmpl, `count > 0`))
  277. if err != nil {
  278. return err
  279. }
  280. // reset the counter immediately
  281. db.MustExec(ConnectionResetCounterTmpl)
  282. return a.PreLoad()
  283. }