diff options
| -rw-r--r-- | rumba/model.py | 27 | 
1 files changed, 24 insertions, 3 deletions
| diff --git a/rumba/model.py b/rumba/model.py index d1d9156..6d6ffee 100644 --- a/rumba/model.py +++ b/rumba/model.py @@ -116,12 +116,18 @@ class DIF:      def del_member(self, node):          self.members.remove(node) +    def get_ipcp_class(self): +        return IPCP +  # Shim over UDP  #  class ShimUDPDIF(DIF):      def __init__(self, name, members = None):          DIF.__init__(self, name, members) +    def get_ipcp_class(self): +        return ShimUDPIPCP +  # Shim over Ethernet  #  # @link_speed [int] Speed of the Ethernet network, in Mbps @@ -133,6 +139,9 @@ class ShimEthDIF(DIF):          if self.link_speed < 0:              raise ValueError("link_speed must be a non-negative number") +    def get_ipcp_class(self): +        return ShimEthIPCP +  # Normal DIF  #  # @policies [dict] Policies of the normal DIF @@ -281,7 +290,7 @@ class Node:          del self.dif_bindings[name]          self._validate() -# Class representing an IPC Process to be created in the experiment +# Base class representing an IPC Process to be created in the experiment  #  # @name [string]: IPCP name  # @node: Node where the IPCP gets created @@ -313,6 +322,16 @@ class IPCP:      def __neq__(self, other):          return not (self == other) +class ShimEthIPCP(IPCP): +    def __init__(self, name, node, dif, ifname = None): +        IPCP.__init__(self, name, node, dif) +        self.ifname = ifname + +class ShimUDPIPCP(IPCP): +    def __init__(self, name, node, dif): +        IPCP.__init__(self, name, node, dif) +        # TODO add IP and port +  # Base class for ARCFIRE experiments  #  # @name [string] Name of the experiment @@ -508,8 +527,10 @@ class Experiment:              for dif in self.dif_ordering:                  if dif not in node.difs:                      continue -                ipcp = IPCP(name = '%s.%s' % (dif.name, node.name), -                            node = node, dif = dif) + +                ipcp = dif.get_ipcp_class()( +                                name = '%s.%s' % (dif.name, node.name), +                                node = node, dif = dif)                  if dif in node.dif_registrations:                      for lower in node.dif_registrations[dif]: | 
