"""Examples of building graphviz diagrams with fastdot.

This script was reconstructed from a notebook export. In the export every
cell had been collapsed onto one line starting with ``#hide``, so the whole
file was a single comment and nothing ever ran.

Bare trailing expressions such as ``g`` or ``graph_items(...)`` are notebook
display cells. They render the graph in Jupyter and have no effect when the
file is run as a plain script.
"""
# hide  (nbdev cell directive from the original notebook; no effect here)
# Star-import order is preserved from the original: later star imports can
# shadow names bound by earlier ones.
from fastdot import *
from fastcore.all import *
from dataclasses import dataclass
from operator import attrgetter  # explicit; was only available via fastcore's star re-export

# --- Cell: two sequential clusters built from plain string labels ---
layers1 = ['conv', 'conv', 'lin']
layers2 = ['conv', 'lin']
block1, block2 = ['block1', 'block2']
# Each pair is (source, target): cluster -> cluster, and cluster -> last item of layers2.
conns = ((block1, block2), (block1, layers2[-1]))
g = graph_items(seq_cluster(layers1, block1), seq_cluster(layers2, block2))
g.add_items(*object_connections(conns))
g  # notebook display

# --- Cell: hand-built graph with a filled cluster and a tooltip node ---
g = Dot()
c = Cluster('cl', fillcolor='pink')
a1, a2, b = c.add_items('a', 'a', 'b')
c.add_items(a1.connect(a2), a2.connect(b))
g.add_item(Node('Check tooltip', tooltip="I have a tooltip!"))
g.add_item(c)
g  # notebook display


# --- Cell: plain Python objects that will be rendered as graph nodes ---
@dataclass(frozen=True)
class Layer:
    name: str
    n_filters: int = 1


class Linear(Layer):
    pass


class Conv2d(Layer):
    pass


@dataclass(frozen=True)
class Sequential:
    # NOTE: frozen=True generates __hash__ from the fields. Because `layers`
    # is a list, hash(Sequential(...)) raises TypeError. Instances must not
    # be used as dict keys or set members.
    layers: list
    name: str


block1 = Sequential([Conv2d('conv', 5), Linear('lin', 3)], 'block1')
block2 = Sequential([Conv2d('conv1', 8), Conv2d('conv2', 2), Linear('lin')], 'block2')

# Default rendering attributes:
# - fill colour depends on the layer's class;
# - node and cluster labels come from the object's `name` attribute;
# - tooltips come from str(obj).
node_defaults['fillcolor'] = lambda o: 'greenyellow' if isinstance(o, Linear) else 'pink'
cluster_defaults['label'] = node_defaults['label'] = attrgetter('name')
node_defaults['tooltip'] = str

# --- Cell: connect clusters and a cluster to a node explicitly ---
c1 = seq_cluster(block1.layers, block1)
c2 = seq_cluster(block2.layers, block2)
e1, e2 = c1.connect(c2), c1.connect(c2.last())
graph_items(c1, c2, e1, e2)  # notebook display

# --- Cell: the same connections, looked up from the original objects ---
conns = (
    (block1, block2),
    (block1, block2.layers[-1]),
)
g = graph_items(seq_cluster(block1.layers, block1), seq_cluster(block2.layers, block2))
object2graph(block1.layers[-1])  # notebook display of the graph item for one object
g.add_items(*[object2graph(a).connect(object2graph(b)) for a, b in conns])
g  # notebook display

# --- Cell: same graph again, using the object_connections helper ---
g = graph_items(seq_cluster(block1.layers, block1), seq_cluster(block2.layers, block2))
g.add_items(*object_connections(conns))
g  # notebook display